ajaxwin committed on
Commit
7203787
·
1 Parent(s): 9c888b7

Task 3 Implemented

Files changed (12)
  1. README.md +141 -130
  2. app.py +17 -3
  3. data/data_loader.py +33 -0
  4. demo.py +76 -0
  5. env/schemas.py +3 -4
  6. eval.py +187 -134
  7. inference.py +100 -10
  8. openenv.yaml +57 -64
  9. tasks/task3/__init__.py +4 -30
  10. tasks/task3/environment.py +350 -0
  11. tasks/task3/grader.py +80 -0
  12. validate.py +189 -167
README.md CHANGED
@@ -1,9 +1,9 @@
1
  # Smart Contract Audit RL Environment
2
 
3
  > **OpenEnv-compliant reinforcement learning environment for smart contract security analysis.**
4
- > Train and evaluate agents on real-world Solidity audit tasks — the same work professional auditors do every day.
5
 
6
- [![OpenEnv Spec](https://img.shields.io/badge/OpenEnv-1.1-blue)](openenv.yaml)
7
  [![Python 3.11+](https://img.shields.io/badge/python-3.11%2B-brightgreen)](https://python.org)
8
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow)](LICENSE)
9
 
@@ -11,58 +11,57 @@
11
 
12
  ## Motivation
13
 
14
- Smart contract auditing is a $500M+ industry where human auditors painstakingly review Solidity code for security flaws and formally specify function properties. This environment lets agents practice exactly that workflow — exploring contract code through targeted queries and submitting findings — providing a rigorous, real-world benchmark for code-reasoning agents.
15
-
16
- Data is sourced from **Certora-audited DeFi projects**, giving agents contracts with the same vulnerability patterns found in production exploits.
17
 
18
  ---
19
 
20
- ## Tasks
21
 
22
- | # | Name | Difficulty | Status | Description |
23
- |---|------|------------|--------|-------------|
24
- | 1 | Targeted Vulnerability Detection | Medium | ✅ Active | Find the vulnerable function and name the vulnerability type |
25
  | 2 | Property Discovery | Hard | ✅ Active | Write the natural-language postcondition for a given function |
26
- | 3 | Rule Checker | Easy | ⏳ Placeholder | Identify which function violates a given property |
27
 
28
  ---
29
 
30
  ## Task 1 — Targeted Vulnerability Detection *(Medium)*
31
 
32
- **Setup:** Agent is shown a Solidity contract (4–6 functions). One function contains a critical vulnerability.
33
 
34
- **Objective:** Identify the vulnerable function and describe its vulnerability type in 2–3 words.
35
 
36
  ### Actions
37
 
38
  | Action | Params | Reward |
39
  |--------|--------|--------|
40
  | `list_functions` | — | −0.05 |
41
- | `get_function_code` | `function_name` | +0.05 (target) / −0.10 (other) |
42
- | `get_function_summary` | `function_name` | +0.03 (target) / −0.05 (other) |
43
  | `get_file_metadata` | — | −0.04 |
44
  | `get_state_variable` | `variable_name` (opt.) | −0.05 |
45
  | `get_call_graph` | — | −0.08 |
46
- | `submit` | `function_name`, `vulnerability_type` | **+5.0** / +1.0 / −1.5 |
 
 
47
 
48
- Repeated identical queries: **−0.40**
49
 
50
- ### Submit scoring (deterministic)
51
- - **1.0** → correct function **+** correct vulnerability keyword → reward +5.0
52
- - **0.5** → correct function, wrong/vague vulnerability type → reward +1.0
53
- - **0.0** → wrong function → reward −1.5
54
 
55
- ### Vulnerability types in dataset
56
  Reentrancy · Missing access control · Integer overflow · tx.origin authentication ·
57
- Front-running · Timestamp dependence · Denial of service (unbounded loop) · Unchecked return value
58
 
59
  ---
60
 
61
  ## Task 2 — Property Discovery *(Hard)*
62
 
63
- **Setup:** Agent is shown a single Solidity function and must write its natural-language correctness property (postcondition / invariant).
64
 
65
- **Objective:** Write a precise 2–4 sentence property describing what the function guarantees when it succeeds.
66
 
67
  ### Actions
68
 
@@ -74,51 +73,74 @@ Front-running · Timestamp dependence · Denial of service (unbounded loop) · U
74
  | `get_related_functions` | — | −0.06 |
75
  | `get_io` | — | −0.04 |
76
  | `get_similar_rule` | — | −0.20 |
77
- | `submit_property` | `property` (string) | **0.0–5.0** (scored, ONE attempt) |
78
 
79
- Repeated identical queries: **−0.40**
80
 
81
- ### Submit scoring (keyword-weighted)
82
  ```
83
- score = 0.70 × (key_phrases_matched / total_key_phrases)
84
- + 0.30 × (bonus_phrases_matched / total_bonus_phrases)
85
-
86
- reward = score × 5.0 → range: 0.0 – 5.0
87
  ```
88
 
89
- Matching uses **word-set containment** with synonym expansion (e.g. "caller" matches "msg.sender", "sender", "user"). Phrases don't need to be adjacent — all constituent words just need to appear somewhere in the submitted text.
 
90
 
91
- **One submission per episode** — choose carefully.
 
 
92
 
93
- ### Property coverage
94
- 11 functions across 4 contracts with ground-truth properties: SimpleVault (deposit, withdraw, emergencyDrain), TokenSale (buyTokens, setPrice, withdrawETH), DutchAuction (getPrice, bid, finalize), YieldFarm (stake, claimRewards).
95
 
96
  ---
97
 
98
  ## Observation Space
99
 
100
- Every `step()` and `reset()` returns the same `Observation` structure:
101
 
102
  ```json
103
  {
104
- "task_id": "task2_property_discovery",
105
- "contract_name": "YieldFarm",
106
- "contract_description": "A simple yield farming contract...",
107
- "available_actions": ["get_function_code", "get_function_natspec", ...],
108
- "last_action": "get_function_natspec",
109
- "last_action_result": "NatSpec for 'claimRewards':\n@notice Claim all accrued...",
110
- "step_count": 2,
111
- "cumulative_reward": -0.14,
112
  "done": false,
113
  "extra": {
114
- "target_function": "claimRewards",
115
- "target_signature": "claimRewards()",
116
- "solidity_version": "0.8.10",
117
- "hint": "Discover the property of the target function..."
118
  }
119
  }
120
  ```
121
 
 
 
 
122
  ---
123
 
124
  ## Project Structure
@@ -126,54 +148,54 @@ Every `step()` and `reset()` returns the same `Observation` structure:
126
  ```
127
  smart-contract-env/
128
  ├── data/
129
- │ ├── contracts.json # 4 contracts · 8 vulnerabilities · 11 properties
130
- │ └── data_loader.py # JSON parser, episode samplers, T1 + T2 helpers
131
  ├── env/
132
  │ ├── base_env.py # Abstract OpenEnv base class
133
- │ └── schemas.py # Pydantic: Observation, Action, Reward, StepResult…
134
  ├── tasks/
135
  │ ├── task1/
136
- │ │ ├── environment.py # Full Task 1 RL environment
137
- │ │ └── grader.py # Deterministic 0/0.5/1.0 rubric + longest-match keywords
138
  │ ├── task2/
139
- │ │ ├── environment.py # Full Task 2 RL environment (one submit per episode)
140
- │ │ └── grader.py # Keyword-weighted 0.0–1.0 grader + synonym expansion
141
- │ └── task3/ # TODO: Rule Checker (placeholder)
142
- ├── app.py # FastAPI server — all OpenEnv HTTP endpoints
143
- ├── inference.py # Baseline LLM agent (Task 1 + Task 2)
144
- ├── eval.py # Oracle/partial/random evaluation harness
145
- ├── demo.py # Colourised interactive + scripted demo
146
- ├── validate.py # 19-check pre-submission validator
147
  ├── openenv.yaml # Full OpenEnv spec metadata
148
- ├── Dockerfile # Port 7860, uvicorn, healthcheck
149
  └── requirements.txt
150
  ```
151
 
152
  ---
153
 
154
- ## Setup & Usage
155
 
156
  ### Local Python
157
 
158
  ```bash
159
- git clone <repo> && cd smart-contract-env
160
  pip install -r requirements.txt
161
 
162
- # Run the server
163
- python app.py # → http://localhost:7860
164
 
165
- # Run interactive demo
166
- python demo.py # Task 1 interactive
167
- python demo.py --auto # Task 1 scripted
168
- python demo.py --auto --task 2 # Task 2 scripted (add --task flag)
169
 
170
- # Run evaluation harness (no LLM needed)
171
- python eval.py # Both tasks, 8 episodes each
172
- python eval.py --task 2 # Task 2 only
173
  python eval.py --episodes 16 --verbose
174
 
175
  # Pre-submission validation
176
- python validate.py # 19/19 checks
177
  ```
178
 
179
  ### Docker
@@ -186,28 +208,19 @@ docker run -p 7860:7860 sc-audit-env
186
  ### Direct Python API
187
 
188
  ```python
189
- from tasks.task1.environment import Task1Environment
190
- from tasks.task2.environment import Task2Environment
191
  from env.schemas import Action, ActionType
192
 
193
- # Task 1
194
- env = Task1Environment()
195
  r = env.reset(seed=42)
196
- print(r.observation.contract_name) # SimpleVault
197
- s = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
198
- s = env.step(Action(action_type=ActionType.SUBMIT,
199
- params={"function_name": "emergencyDrain",
200
- "vulnerability_type": "missing access control"}))
201
- print(s.reward.value) # +5.0
202
-
203
- # Task 2
204
- env2 = Task2Environment()
205
- r2 = env2.reset(seed=42)
206
- print(r2.observation.extra["target_function"]) # claimRewards
207
- s2 = env2.step(Action(action_type=ActionType.GET_FUNCTION_NATSPEC))
208
- s2 = env2.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
209
- params={"property": "After a successful claimRewards call, all accrued reward tokens are transferred to the caller and their rewards balance is zeroed. Reverts if no rewards."}))
210
- print(s2.reward.value) # ~4.0
211
  ```
212
 
213
  ---
@@ -220,20 +233,23 @@ print(s2.reward.value) # ~4.0
220
  | `GET` | `/tasks` | All tasks + status |
221
  | `POST` | `/reset` | Start episode (`task_id`, `seed`) |
222
  | `POST` | `/step` | Take action (`action_type`, `params`) |
223
- | `GET` | `/state` | Debug: internal episode state |
224
- | `GET` | `/action_space?task_id=...` | Action schema for a task |
225
  | `GET` | `/observation_space` | Observation schema |
226
 
227
  ```bash
228
- # Task 2 full episode
229
  curl -X POST localhost:7860/reset \
230
- -d '{"task_id":"task2_property_discovery","seed":42}'
 
231
 
232
  curl -X POST localhost:7860/step \
233
- -d '{"action_type":"get_function_natspec","params":{}}'
 
234
 
235
  curl -X POST localhost:7860/step \
236
- -d '{"action_type":"submit_property","params":{"property":"..."}}'
 
237
  ```
238
 
239
  ---
@@ -244,45 +260,30 @@ curl -X POST localhost:7860/step \
244
  export API_BASE_URL="https://api.openai.com/v1"
245
  export MODEL_NAME="gpt-4o-mini"
246
  export HF_TOKEN="sk-..."
247
-
248
  python inference.py
249
- # → baseline_scores.json
250
  ```
251
 
252
- ### Expected baseline scores (gpt-4o-mini, 3 episodes per task)
253
 
254
  | Task | Avg Grader Score | Notes |
255
  |------|-----------------|-------|
256
- | Task 1 | ~0.67 | Good at common vulns; misses subtle ones |
257
- | Task 2 | ~0.55 | Reasonable properties but often misses specific variable names |
258
- | Task 3 | 0.00 | Placeholder |
259
 
260
  ---
261
 
262
- ## Evaluation Scores
263
 
264
- Deterministic oracle / partial / baseline tiers verified on 8 episodes (seeds 42–49):
265
 
266
- | Task | Oracle | Partial | Floor |
267
- |------|--------|---------|-------|
268
- | Task 1 | **1.000** | 0.500 | 0.000 |
269
- | Task 2 | **0.775** | 0.034 | 0.000 |
 
270
 
271
- The clear separation confirms the grader provides **meaningful gradient signal** for RL training.
272
-
273
- ---
274
-
275
- ## Deploying to Hugging Face Spaces
276
-
277
- 1. Create a new **Docker** Space at [huggingface.co/spaces](https://huggingface.co/spaces)
278
- 2. Add tag `openenv` in the Space settings
279
- 3. Copy the `SPACES_README.md` frontmatter into `README.md`
280
- 4. Push:
281
-
282
- ```bash
283
- git remote add hf https://huggingface.co/spaces/<user>/<space>
284
- git push hf main
285
- ```
286
 
287
  ---
288
 
@@ -295,15 +296,25 @@ git push hf main
295
  | `reset() → ResetResult` | ✅ |
296
  | `state() → StateResult` | ✅ |
297
  | `openenv.yaml` metadata | ✅ |
298
- | 3+ tasks defined | ✅ (2 active, 1 placeholder) |
299
  | Grader scores in [0.0, 1.0] | ✅ |
300
- | Shaped rewards (non-binary) | ✅ |
301
  | Dockerfile + port 7860 | ✅ |
302
  | `inference.py` with OpenAI client | ✅ |
303
- | `validate.py` — all 19 checks pass | ✅ |
304
 
305
  ---
306
 
307
  ## License
308
 
309
- MIT. Contract vulnerability data adapted from Certora audits on production DeFi protocols.
 
1
  # Smart Contract Audit RL Environment
2
 
3
  > **OpenEnv-compliant reinforcement learning environment for smart contract security analysis.**
4
+ > Three fully implemented tasks covering the core workflow of a professional Solidity auditor.
5
 
6
+ [![OpenEnv Spec](https://img.shields.io/badge/OpenEnv-1.2-blue)](openenv.yaml)
7
  [![Python 3.11+](https://img.shields.io/badge/python-3.11%2B-brightgreen)](https://python.org)
8
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow)](LICENSE)
9
 
 
11
 
12
  ## Motivation
13
 
14
+ Smart contract auditing is a $500M+ industry where human experts identify security flaws, write formal properties, and check whether code satisfies those properties. This environment lets agents practise exactly those three tasks using real Solidity contracts from Certora-audited DeFi projects.
 
 
15
 
16
  ---
17
 
18
+ ## Tasks at a Glance
19
 
20
+ | # | Name | Difficulty | Status | One-line description |
21
+ |---|------|-----------|--------|---------------------|
22
+ | 1 | Targeted Vulnerability Detection | Medium | ✅ Active | Find which function is vulnerable and name the vulnerability |
23
  | 2 | Property Discovery | Hard | ✅ Active | Write the natural-language postcondition for a given function |
24
+ | 3 | Rule Checker | Easy | ✅ Active | Identify which function violates a given property |
25
 
26
  ---
27
 
28
  ## Task 1 — Targeted Vulnerability Detection *(Medium)*
29
 
30
+ **Setup:** A Solidity contract (4–6 functions) is shown. One function contains a critical vulnerability.
31
 
32
+ **Objective:** Name the vulnerable function and describe its vulnerability type in 2–3 words.
33
 
34
  ### Actions
35
 
36
  | Action | Params | Reward |
37
  |--------|--------|--------|
38
  | `list_functions` | — | −0.05 |
39
+ | `get_function_code` | `function_name` | +0.05 if target / −0.10 if other |
40
+ | `get_function_summary` | `function_name` | +0.03 if target / −0.05 if other |
41
  | `get_file_metadata` | — | −0.04 |
42
  | `get_state_variable` | `variable_name` (opt.) | −0.05 |
43
  | `get_call_graph` | — | −0.08 |
44
+ | `submit` | `function_name`, `vulnerability_type` | **+5.0 / +1.0 / −1.5** |
45
+
46
+ Repeated queries: **−0.40**
47
 
48
+ ### Grader
49
 
50
+ - **1.0** → correct function + correct vulnerability keyword → reward **+5.0**
51
+ - **0.5** → correct function, vague/wrong vulnerability type → reward **+1.0**
52
+ - **0.0** → wrong function → reward **−1.5**
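The tier logic is simple enough to sketch in a few lines of Python. This is an illustrative reconstruction, not the actual `tasks/task1/grader.py` (whose keyword lists and text normalisation may differ):

```python
def grade_task1(submitted_fn, submitted_vuln, target_fn, vuln_keywords):
    """Illustrative sketch of the three-tier Task 1 rubric."""
    if submitted_fn.strip().lower() != target_fn.lower():
        return 0.0, -1.5                       # wrong function
    answer_words = set(submitted_vuln.lower().split())
    for keyword in vuln_keywords:              # e.g. "missing access control"
        if all(w in answer_words for w in keyword.lower().split()):
            return 1.0, 5.0                    # correct function + keyword match
    return 0.5, 1.0                            # correct function, vague type

print(grade_task1("emergencyDrain", "missing access control",
                  "emergencyDrain", ["missing access control"]))  # (1.0, 5.0)
```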
 
53
 
54
+ ### Vulnerability types covered
55
  Reentrancy · Missing access control · Integer overflow · tx.origin authentication ·
56
+ Front-running · Timestamp dependence · Denial of service · Unchecked return value
57
 
58
  ---
59
 
60
  ## Task 2 — Property Discovery *(Hard)*
61
 
62
+ **Setup:** A single Solidity function is shown. The agent must discover its natural-language correctness property.
63
 
64
+ **Objective:** Write a precise 2–4 sentence postcondition describing what the function guarantees on success.
65
 
66
  ### Actions
67
 
 
73
  | `get_related_functions` | — | −0.06 |
74
  | `get_io` | — | −0.04 |
75
  | `get_similar_rule` | — | −0.20 |
76
+ | `submit_property` | `property` (string) | **0.0–5.0** scored, ONE attempt |
77
 
78
+ ### Grader (keyword-weighted)
79
 
 
80
  ```
81
+ score = 0.70 × (key_phrases_matched / total_key)
82
+ + 0.30 × (bonus_phrases_matched / total_bonus)
83
+ reward = score × 5.0
 
84
  ```
85
 
86
+ Matching uses **word-set containment + synonym expansion** — words don't need to be adjacent.
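A minimal sketch of this scoring scheme (the synonym table and phrase lists below are illustrative placeholders, not the dataset's real ones):

```python
# Illustrative synonym table; the real grader's table is larger.
SYNONYMS = {"caller": {"caller", "msg.sender", "sender", "user"}}

def phrase_matches(phrase, answer_words):
    # Word-set containment: every word of the phrase (or one of its
    # synonyms) must appear somewhere; adjacency is not required.
    return all(SYNONYMS.get(w, {w}) & answer_words for w in phrase.lower().split())

def grade_task2(answer, key_phrases, bonus_phrases):
    words = set(answer.lower().split())
    key = sum(phrase_matches(p, words) for p in key_phrases) / len(key_phrases)
    bonus = (sum(phrase_matches(p, words) for p in bonus_phrases) / len(bonus_phrases)
             if bonus_phrases else 0.0)
    score = 0.70 * key + 0.30 * bonus
    return score, score * 5.0                  # grader score, episode reward

score, reward = grade_task2(
    "All accrued rewards are transferred to msg.sender and the balance is zeroed",
    key_phrases=["rewards transferred caller", "balance zeroed"],
    bonus_phrases=["reverts if zero"])
print(round(score, 2), round(reward, 2))  # 0.7 3.5
```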
87
+
88
+ ---
89
+
90
+ ## Task 3 — Rule Checker *(Easy)*
91
+
92
+ **Setup:** A Solidity contract is shown alongside a violated property in natural English. One function breaks that property.
93
+
94
+ **Objective:** Identify which function violates the property.
95
+
96
+ ### Actions
97
+
98
+ | Action | Params | Reward |
99
+ |--------|--------|--------|
100
+ | `list_functions` | — | −0.05 |
101
+ | `get_function_metadata` | `function_name` | −0.05 |
102
+ | `get_function_code` | `function_name` | −0.10 |
103
+ | `get_state_variable` | `variable_name` (opt.) | −0.05 |
104
+ | `get_call_graph` | — | −0.08 |
105
+ | `get_formalized_property` | — | **−0.03** (cheapest — read this first!) |
106
+ | `submit_function` | `function_name` | **+5.0 / +1.5 / −1.5**, ONE attempt |
107
+
108
+ ### Grader (three-tier deterministic)
109
 
110
+ - **1.0** → exact target function (case-insensitive) → reward **+5.0**
111
+ - **0.3** → a direct internal subfunction of the target → reward **+1.5**
112
+ - **0.0** → anything else → reward **−1.5**
113
 
114
+ `get_formalized_property` returns the precise pre/post-condition (`rule_broken_specs`). Reading it costs only −0.03 and usually provides enough information to identify the violating function without inspecting all code.
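The three tiers can be sketched as follows (a hypothetical reconstruction; the real `tasks/task3/grader.py` would derive the target's internal subfunctions from the call graph):

```python
def grade_task3(submitted, target, internal_subcalls):
    """Illustrative three-tier grader; internal_subcalls are the functions
    the target calls directly (taken from the call graph)."""
    s = submitted.strip().lower()
    if s == target.lower():
        return 1.0, 5.0                        # exact target function
    if s in {f.lower() for f in internal_subcalls}:
        return 0.3, 1.5                        # direct internal subfunction
    return 0.0, -1.5                           # anything else

print(grade_task3("EmergencyDrain", "emergencyDrain", {"_sendAll"}))  # (1.0, 5.0)
```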
 
115
 
116
  ---
117
 
118
  ## Observation Space
119
 
120
+ All tasks share the same `Observation` structure:
121
 
122
  ```json
123
  {
124
+ "task_id": "task3_rule_checker",
125
+ "contract_name": "SimpleVault",
126
+ "contract_description": "An ETH vault that allows users to deposit...",
127
+ "available_actions": ["list_functions", "get_function_metadata", "..."],
128
+ "last_action": "get_formalized_property",
129
+ "last_action_result": "Formal property:\nPre: caller != owner...",
130
+ "step_count": 1,
131
+ "cumulative_reward": -0.03,
132
  "done": false,
133
  "extra": {
134
+ "property_english": "Only the owner should be able to drain the vault...",
135
+ "solidity_version": "0.8.0",
136
+ "hint": "Find the function that violates this property..."
 
137
  }
138
  }
139
  ```
140
 
141
+ For Task 2, `extra` contains `target_function` and `target_signature`.
142
+ For Task 3, `extra` contains `property_english`.
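Client code that does not import the server's Pydantic models can mirror this structure with a stdlib dataclass. Field names are taken from the JSON above; `env/schemas.py` remains the authoritative definition:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

@dataclass
class Observation:
    task_id: str
    contract_name: str
    contract_description: str
    available_actions: List[str]
    last_action: Optional[str] = None
    last_action_result: Optional[str] = None
    step_count: int = 0
    cumulative_reward: float = 0.0
    done: bool = False
    extra: Dict[str, Any] = field(default_factory=dict)

obs = Observation(
    task_id="task3_rule_checker",
    contract_name="SimpleVault",
    contract_description="An ETH vault that allows users to deposit...",
    available_actions=["list_functions", "get_function_metadata"],
    extra={"property_english": "Only the owner should be able to drain the vault..."})
print(obs.contract_name, obs.step_count)  # SimpleVault 0
```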
143
+
144
  ---
145
 
146
  ## Project Structure
 
148
  ```
149
  smart-contract-env/
150
  ├── data/
151
+ │ ├── contracts.json # 4 contracts, 8 vulns, 11 properties, 8 rule episodes
152
+ │ └── data_loader.py # loaders for all three tasks
153
  ├── env/
154
  │ ├── base_env.py # Abstract OpenEnv base class
155
+ │ └── schemas.py # Typed Pydantic models (all ActionTypes)
156
  ├── tasks/
157
  │ ├── task1/
158
+ │ │ ├── environment.py # Vulnerability detection environment
159
+ │ │ └── grader.py # Longest-match keyword grader (0/0.5/1.0)
160
  │ ├── task2/
161
+ │ │ ├── environment.py # Property discovery (one submit_property)
162
+ │ │ └── grader.py # Word-set + synonym grader (0.0–1.0)
163
+ │ └── task3/
164
+ │ ├── environment.py # Rule checker (one submit_function)
165
+ │ └── grader.py # Three-tier grader (1.0/0.3/0.0)
166
+ ├── app.py # FastAPI — all OpenEnv HTTP endpoints
167
+ ├── inference.py # Baseline LLM agent (all 3 tasks)
168
+ ├── eval.py # Oracle/partial/floor evaluation harness
169
+ ├── demo.py # Colourised scripted demos for all 3 tasks
170
+ ├── validate.py # 23-check pre-submission validator
171
  ├── openenv.yaml # Full OpenEnv spec metadata
172
+ ├── Dockerfile # Port 7860, healthcheck
173
  └── requirements.txt
174
  ```
175
 
176
  ---
177
 
178
+ ## Setup
179
 
180
  ### Local Python
181
 
182
  ```bash
 
183
  pip install -r requirements.txt
184
 
185
+ # Start the server
186
+ python app.py # → http://localhost:7860
187
 
188
+ # Interactive / scripted demos
189
+ python demo.py --auto # Task 1 scripted demo
190
+ python demo.py --auto --seed 42 # Task 2 (same flag, different env seed)
 
191
 
192
+ # Full evaluation harness (no LLM required)
193
+ python eval.py # All 3 tasks, 8 episodes each
194
+ python eval.py --task 3 # Task 3 only
195
  python eval.py --episodes 16 --verbose
196
 
197
  # Pre-submission validation
198
+ python validate.py # 23/23 checks
199
  ```
200
 
201
  ### Docker
 
208
  ### Direct Python API
209
 
210
  ```python
211
+ # Task 3 example
212
+ from tasks.task3.environment import Task3Environment
213
  from env.schemas import Action, ActionType
214
 
215
+ env = Task3Environment()
 
216
  r = env.reset(seed=42)
217
+ print(r.observation.extra["property_english"])
218
+ # "Only the owner should be able to drain the vault..."
219
+
220
+ s = env.step(Action(action_type=ActionType.GET_FORMALIZED_PROPERTY))
221
+ s = env.step(Action(action_type=ActionType.SUBMIT_FUNCTION,
222
+ params={"function_name": "emergencyDrain"}))
223
+ print(s.reward.value) # +5.0
 
224
  ```
225
 
226
  ---
 
233
  | `GET` | `/tasks` | All tasks + status |
234
  | `POST` | `/reset` | Start episode (`task_id`, `seed`) |
235
  | `POST` | `/step` | Take action (`action_type`, `params`) |
236
+ | `GET` | `/state` | Internal debug state |
237
+ | `GET` | `/action_space?task_id=...` | Action schema |
238
  | `GET` | `/observation_space` | Observation schema |
239
 
240
  ```bash
241
+ # Full Task 3 episode
242
  curl -X POST localhost:7860/reset \
243
+ -H "Content-Type: application/json" \
244
+ -d '{"task_id":"task3_rule_checker","seed":42}'
245
 
246
  curl -X POST localhost:7860/step \
247
+ -H "Content-Type: application/json" \
248
+ -d '{"action_type":"get_formalized_property","params":{}}'
249
 
250
  curl -X POST localhost:7860/step \
251
+ -H "Content-Type: application/json" \
252
+ -d '{"action_type":"submit_function","params":{"function_name":"emergencyDrain"}}'
253
  ```
254
 
255
  ---
 
260
  export API_BASE_URL="https://api.openai.com/v1"
261
  export MODEL_NAME="gpt-4o-mini"
262
  export HF_TOKEN="sk-..."
 
263
  python inference.py
 
264
  ```
265
 
266
+ ### Expected scores (gpt-4o-mini, 3 episodes per task)
267
 
268
  | Task | Avg Grader Score | Notes |
269
  |------|-----------------|-------|
270
+ | Task 1 | ~0.67 | Good at classic vulns; struggles with subtle ones |
271
+ | Task 2 | ~0.55 | Reasonable properties; misses specific variable names |
272
+ | Task 3 | ~0.78 | Property text gives strong signal; usually correct in 3–4 steps |
273
 
274
  ---
275
 
276
+ ## Evaluation Summary
277
 
278
+ Deterministic oracle / partial / floor tiers verified on 8 episodes (seeds 42–49):
279
 
280
+ | Task | Oracle | Partial/Sub | Floor | Ordering |
281
+ |------|--------|-------------|-------|----------|
282
+ | Task 1 | **1.000** | 0.500 | 0.000 | ✅ 1.0 > 0.5 > 0.0 |
283
+ | Task 2 | **0.775** | 0.034 | 0.000 | ✅ 0.775 > 0.034 > 0.0 |
284
+ | Task 3 | **1.000** | 0.037 | 0.000 | ✅ 1.0 > 0.037 > 0.0 |
285
 
286
+ The clear separation across all three tasks confirms the graders provide **meaningful gradient signal** across the full reward range — a core requirement for RL training environments.
 
287
 
288
  ---
289
 
 
296
  | `reset() → ResetResult` | ✅ |
297
  | `state() → StateResult` | ✅ |
298
  | `openenv.yaml` metadata | ✅ |
299
+ | 3 tasks, all active | ✅ |
300
  | Grader scores in [0.0, 1.0] | ✅ |
301
+ | Shaped rewards (non-binary signal) | ✅ |
302
  | Dockerfile + port 7860 | ✅ |
303
  | `inference.py` with OpenAI client | ✅ |
304
+ | `validate.py` — 23/23 checks pass | ✅ |
305
+
306
+ ---
307
+
308
+ ## Deploying to Hugging Face Spaces
309
+
310
+ ```bash
311
+ # Copy the HF frontmatter into README.md, then:
312
+ git remote add hf https://huggingface.co/spaces/<user>/<space>
313
+ git push hf main
314
+ ```
315
 
316
  ---
317
 
318
  ## License
319
 
320
+ MIT. Contract vulnerability patterns adapted from Certora audits on production DeFi protocols.
app.py CHANGED
@@ -24,6 +24,7 @@ from pydantic import BaseModel
24
  from env.schemas import Action, ActionType, TaskInfo
25
  from tasks.task1.environment import Task1Environment
26
  from tasks.task2.environment import Task2Environment
 
27
 
28
  # ─────────────────────────────────────────────────────────────────────────────
29
  # App
@@ -35,7 +36,7 @@ app = FastAPI(
35
  "OpenEnv-compliant reinforcement learning environment for smart contract "
36
  "security analysis. Train and evaluate agents on real-world Solidity audit tasks."
37
  ),
38
- version="1.1.0",
39
  )
40
 
41
  # ─────────────────────────────────────────────────────────────────────────────
@@ -48,7 +49,7 @@ DEFAULT_SESSION = "default"
48
  TASK_ENV_MAP = {
49
  "task1_vuln_detection": Task1Environment,
50
  "task2_property_discovery": Task2Environment,
51
- # TODO: "task3_rule_checker": Task3Environment,
52
  }
53
 
54
 
@@ -109,7 +110,7 @@ def list_tasks():
109
  name="Rule Checker",
110
  difficulty="easy",
111
  description="Given a property in English and a Solidity contract, identify which function violates that property.",
112
- status="placeholder",
113
  ),
114
  ]
115
  return {"tasks": [t.model_dump() for t in tasks]}
@@ -195,6 +196,19 @@ def action_space(task_id: str = "task1_vuln_detection"):
195
  {"type": "submit_property", "params": {"property": "string"}, "reward": "0.0–5.0 (scored)", "description": "Submit property. ONE attempt. Ends episode."},
196
  ],
197
  }
 
198
  return {"error": f"No action space defined for task '{task_id}'"}
199
 
200
 
 
24
  from env.schemas import Action, ActionType, TaskInfo
25
  from tasks.task1.environment import Task1Environment
26
  from tasks.task2.environment import Task2Environment
27
+ from tasks.task3.environment import Task3Environment
28
 
29
  # ─────────────────────────────────────────────────────────────────────────────
30
  # App
 
36
  "OpenEnv-compliant reinforcement learning environment for smart contract "
37
  "security analysis. Train and evaluate agents on real-world Solidity audit tasks."
38
  ),
39
+ version="1.2.0",
40
  )
41
 
42
  # ─────────────────────────────────────────────────────────────────────────────
 
49
  TASK_ENV_MAP = {
50
  "task1_vuln_detection": Task1Environment,
51
  "task2_property_discovery": Task2Environment,
52
+ "task3_rule_checker": Task3Environment,
53
  }
54
 
55
 
 
110
  name="Rule Checker",
111
  difficulty="easy",
112
  description="Given a property in English and a Solidity contract, identify which function violates that property.",
113
+ status="active",
114
  ),
115
  ]
116
  return {"tasks": [t.model_dump() for t in tasks]}
 
196
  {"type": "submit_property", "params": {"property": "string"}, "reward": "0.0–5.0 (scored)", "description": "Submit property. ONE attempt. Ends episode."},
197
  ],
198
  }
199
+ if task_id == "task3_rule_checker":
200
+ return {
201
+ "task_id": task_id,
202
+ "actions": [
203
+ {"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
204
+ {"type": "get_function_metadata", "params": {"function_name": "string"}, "reward": -0.05, "description": "Get signature, visibility, params of a function"},
205
+ {"type": "get_function_code", "params": {"function_name": "string"}, "reward": -0.10, "description": "Read full Solidity source of a function"},
206
+ {"type": "get_state_variable", "params": {"variable_name": "string (opt)"}, "reward": -0.05, "description": "Get a state variable or list all"},
207
+ {"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
208
+ {"type": "get_formalized_property", "params": {}, "reward": -0.03, "description": "Get formal pre/post-condition for the property"},
209
+ {"type": "submit_function", "params": {"function_name": "string"}, "reward": "+5.0 / +1.5 / -1.5", "description": "Submit answer. ONE attempt. Ends episode."},
210
+ ],
211
+ }
212
  return {"error": f"No action space defined for task '{task_id}'"}
213
 
214
 
data/data_loader.py CHANGED
@@ -193,3 +193,36 @@ def get_similar_rule(
193
  "natspec": "",
194
  }
195
  return None
 
193
  "natspec": "",
194
  }
195
  return None
196
+
197
+
198
+ # ────────────────────────────────────────────────────────────────
199
+ # Task 3 helpers
200
+ # ────────────────────────────────────────────────────────────────
201
+
202
+ def get_all_task3_entries(
203
+ contracts: List[Dict[str, Any]],
204
+ ) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
205
+ """
206
+ Returns (contract, function) pairs where function has a task3 field
207
+ with a non-empty property_english. These are the episode pool for Task 3.
208
+ """
209
+ entries = []
210
+ for contract in contracts:
211
+ for fn in contract.get("functions", []):
212
+ t3 = fn.get("task3", {})
213
+ if t3.get("property_english"):
214
+ entries.append((contract, fn))
215
+ return entries
216
+
217
+
218
+ def sample_task3_episode(
219
+ contracts: List[Dict[str, Any]],
220
+ rng: Optional[random.Random] = None,
221
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
222
+ """Randomly selects one (contract, vulnerable_function) pair for Task 3."""
223
+ if rng is None:
224
+ rng = random.Random()
225
+ entries = get_all_task3_entries(contracts)
226
+ if not entries:
227
+ raise ValueError("No Task 3 entries found in dataset.")
228
+ return rng.choice(entries)
demo.py CHANGED
@@ -355,3 +355,79 @@ def run_auto_demo_t2(seed: int = 42, delay: float = 0.9):
355
  if step_result.done:
356
  _print_episode_summary(sobs)
357
  return
 
355
  if step_result.done:
356
  _print_episode_summary(sobs)
357
  return
358
+
359
+
360
+ # ─────────────────────────────────────────────────────────────────────────────
361
+ # Task 3 demo
362
+ # ─────────────────────────────────────────────────────────────────────────────
363
+
364
+ DEMO_SCRIPTS_T3 = {
365
+ 42: [
366
+ (ActionType.GET_FORMALIZED_PROPERTY, {},
367
+ "Read the formal spec first — cheapest action at -0.03."),
368
+ (ActionType.LIST_FUNCTIONS, {},
369
+ "List all functions to survey candidates."),
370
+ (ActionType.GET_FUNCTION_CODE, {"function_name": "emergencyDrain"},
371
+ "No access modifier! Anyone can call this — that's the violation."),
372
+ (ActionType.SUBMIT_FUNCTION, {"function_name": "emergencyDrain"},
373
+ "Confident. emergencyDrain violates the access-control property."),
374
+ ],
375
+ 45: [
376
+ (ActionType.GET_FORMALIZED_PROPERTY, {},
377
+ "Formal spec: first caller at valid price should win."),
378
+ (ActionType.LIST_FUNCTIONS, {},
379
+ "Auction contract — bid() immediately looks suspicious."),
380
+ (ActionType.GET_FUNCTION_CODE, {"function_name": "bid"},
381
+ "No commit-reveal, no maxPrice guard — front-running is trivially possible."),
382
+ (ActionType.SUBMIT_FUNCTION, {"function_name": "bid"},
383
+ "bid() violates the front-running property. Submitting."),
384
+ ],
385
+ }
386
+
387
+
388
+ def run_auto_demo_t3(seed: int = 42, delay: float = 0.9):
389
+ """Run the scripted Task 3 demo."""
390
+ from tasks.task3.environment import Task3Environment
391
+
392
+ script = DEMO_SCRIPTS_T3.get(seed)
393
+ env = Task3Environment()
394
+ result = env.reset(seed=seed)
395
+ obs = result.observation
396
+
397
+ print()
398
+ print(f"{BOLD}{CYAN}╔══════════════════════════════════════════════════════════╗")
399
+ print(f"║ Smart Contract Audit RL Env · Task 3 Demo ║")
400
+ print(f"╚══════════════════════════════════════════════════════════╝{RESET}")
401
+ print()
402
+ print(f"{BOLD}Mode:{RESET} Automated demo | {BOLD}Seed:{RESET} {seed}")
403
+ print(f"{BOLD}Task:{RESET} Rule Checker")
404
+ print()
405
+
406
+ prop = obs.extra.get("property_english", "")
407
+ print(f"{BOLD}Contract :{RESET} {obs.contract_name}")
408
+ print(f"{BOLD}Property :{RESET} {prop[:100]}{'...' if len(prop) > 100 else ''}")
409
+ print(f"{BOLD}Goal :{RESET} Find the function that violates this property.")
410
+ print(DIVIDER)
411
+
412
+ if not script:
413
+ print(f"{YELLOW}No pre-written script for seed {seed}. Try seed 42 or 45.{RESET}")
414
+ return
415
+
416
+ for at, params, commentary in script:
417
+ time.sleep(delay)
418
+ print(f"\n{CYAN}β–Ά Agent thinking:{RESET} {commentary}")
419
+ time.sleep(delay * 0.5)
420
+ step_result = env.step(Action(action_type=at, params=params))
421
+ sobs = step_result.observation
422
+ print(DIVIDER)
423
+ print(f"{BOLD}Step {sobs.step_count:2d}{RESET} [{at.value}] "
424
+ f"r={step_result.reward.value:+.2f} cum={sobs.cumulative_reward:+.2f}")
425
+ result_text = sobs.last_action_result or ""
426
+ colour = GREEN if step_result.reward.value > 0 else YELLOW
427
+ for line in result_text.split("\n")[:6]:
428
+ print(f" {colour}{line[:90]}{RESET}")
429
+ print(DIVIDER)
430
+
431
+ if step_result.done:
432
+ _print_episode_summary(sobs)
433
+ return
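The scripted demo above is driven by a seed-keyed table of `(action, params, commentary)` tuples. A minimal standalone sketch of that replay pattern, using plain strings in place of the repo's `ActionType` enum (the step entries here are illustrative, not the actual `DEMO_SCRIPTS_T3` content):

```python
# Each scripted demo step is (action_type, params, commentary), mirroring the
# shape of DEMO_SCRIPTS_T3. Names below are illustrative strings, not repo code.
SCRIPT = [
    ("get_formalized_property", {}, "Read the formal spec first."),
    ("list_functions", {}, "Survey candidate functions."),
    ("submit_function", {"function_name": "emergencyDrain"}, "Confident."),
]

def replay(script):
    """Turn a demo script into a printable transcript, one line per step."""
    lines = []
    for step, (action, params, commentary) in enumerate(script, start=1):
        lines.append(f"Step {step}: {action} {params}  # {commentary}")
    return lines

transcript = replay(SCRIPT)
```

In the real demo, each tuple is additionally fed through `env.step(...)` with a delay between the commentary print and the result print.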
env/schemas.py CHANGED
@@ -42,10 +42,9 @@ class ActionType(str, Enum):
42
  SUBMIT_PROPERTY = "submit_property" # scored 0–5, one attempt
43
 
44
  # ── Task 3 – Rule Checker ────────────────────────────────────────────────
45
- # TODO: Task 3
46
- # GET_FORMALIZED_PROPERTY = "get_formalized_property"
47
- # GET_FUNCTION_METADATA = "get_function_metadata"
48
- # SUBMIT_FUNCTION = "submit_function"
49
 
50
 
51
  class Action(BaseModel):
 
42
  SUBMIT_PROPERTY = "submit_property" # scored 0–5, one attempt
43
 
44
  # ── Task 3 – Rule Checker ────────────────────────────────────────────────
45
+ GET_FORMALIZED_PROPERTY = "get_formalized_property" # -0.03
46
+ GET_FUNCTION_METADATA = "get_function_metadata" # -0.05
47
+ SUBMIT_FUNCTION = "submit_function" # +5.0 / +1.5 / -1.5, one attempt
 
48
 
49
 
50
  class Action(BaseModel):
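The per-action costs annotated in the new schema comments can be captured in a small reward table. A standalone sketch under the assumption that the costs match the comments above; `STEP_REWARDS` and `submit_reward` are hypothetical helper names, not repo code:

```python
from enum import Enum

class ActionType(str, Enum):
    # Task 3 – Rule Checker actions, as added in env/schemas.py
    GET_FORMALIZED_PROPERTY = "get_formalized_property"  # cheapest probe
    GET_FUNCTION_METADATA = "get_function_metadata"
    SUBMIT_FUNCTION = "submit_function"                  # terminal, one attempt

# Hypothetical per-step cost table matching the schema comments (-0.03 / -0.05).
STEP_REWARDS = {
    ActionType.GET_FORMALIZED_PROPERTY: -0.03,
    ActionType.GET_FUNCTION_METADATA: -0.05,
}

def submit_reward(is_target: bool, is_partial: bool) -> float:
    """One-attempt submission: exact hit, partial-credit callee, or miss."""
    if is_target:
        return 5.0
    return 1.5 if is_partial else -1.5
```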
eval.py CHANGED
@@ -1,33 +1,32 @@
1
  """
2
  eval.py
3
  -------
4
- Evaluation harness for the Smart Contract Audit RL Environment.
5
 
6
- Runs oracle / partial / baseline agents against Task 1 and Task 2,
7
- verifying that grader scores form a clear ordering and that reward
8
- shaping is meaningful.
9
 
10
  Usage:
11
- python eval.py # Task 1 + Task 2, 8 episodes each
12
- python eval.py --task 1 # Task 1 only
13
- python eval.py --task 2 # Task 2 only
14
- python eval.py --episodes 16 # more episodes
15
- python eval.py --seed 0 --verbose # detailed per-step trace
16
- python eval.py --out results.json # custom output file
17
  """
18
 
19
  import argparse
20
  import json
21
- import sys
22
  from typing import Any, Dict, List
23
 
24
  from tasks.task1.environment import Task1Environment
25
  from tasks.task2.environment import Task2Environment
 
26
  from env.schemas import Action, ActionType
27
  from data.data_loader import (
28
  load_contracts,
29
  get_function_by_name,
30
  get_all_vulnerable_entries,
 
 
31
  )
32
 
33
 
@@ -36,12 +35,10 @@ from data.data_loader import (
36
  # ─────────────────────────────────────────────────────────────────────────────
37
 
38
  def oracle_t1(env: Task1Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
39
- """Always submits the exact ground-truth answer β†’ score = 1.0."""
40
  r = env.reset(seed=seed)
41
  obs = r.observation
42
- st = env.state()
43
- fn_name = st.target_function
44
-
45
  contracts = load_contracts()
46
  vuln_issue = ""
47
  for c in contracts:
@@ -49,42 +46,33 @@ def oracle_t1(env: Task1Environment, seed: int, verbose: bool = False) -> Dict[s
49
  if fn and fn.get("vulnerable"):
50
  vuln_issue = fn["vulnerability_details"]["issue"]
51
  break
52
-
53
  if verbose:
54
  print(f" {obs.contract_name}.{fn_name}() [{vuln_issue}]")
55
-
56
  env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
57
  env.step(Action(action_type=ActionType.GET_FUNCTION_CODE,
58
  params={"function_name": fn_name}))
59
  result = env.step(Action(action_type=ActionType.SUBMIT,
60
  params={"function_name": fn_name,
61
  "vulnerability_type": vuln_issue}))
62
-
63
  v = result.reward.value
64
  score = 1.0 if v >= 4.9 else (0.5 if v >= 0.9 else 0.0)
65
- return {
66
- "seed": seed,
67
- "contract": obs.contract_name,
68
- "target_function": fn_name,
69
- "vulnerability": vuln_issue,
70
- "grader_score": score,
71
- "cumulative_reward": result.observation.cumulative_reward,
72
- }
73
 
74
 
75
  def partial_t1(env: Task1Environment, seed: int) -> Dict[str, Any]:
76
- """Right function, wrong vuln type β†’ score = 0.5."""
77
  env.reset(seed=seed)
78
  fn_name = env.state().target_function
79
  result = env.step(Action(action_type=ActionType.SUBMIT,
80
- params={"function_name": fn_name,
81
- "vulnerability_type": "unknown"}))
82
  v = result.reward.value
83
  return {"seed": seed, "grader_score": 0.5 if v >= 0.9 else 0.0,
84
  "cumulative_reward": result.observation.cumulative_reward}
85
 
86
 
87
- def random_t1(env: Task1Environment, seed: int) -> Dict[str, Any]:
88
  """Always submits 'constructor' β†’ score = 0.0."""
89
  env.reset(seed=seed)
90
  result = env.step(Action(action_type=ActionType.SUBMIT,
@@ -99,12 +87,11 @@ def random_t1(env: Task1Environment, seed: int) -> Dict[str, Any]:
99
  # ─────────────────────────────────────────────────────────────────────────────
100
 
101
  def oracle_t2(env: Task2Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
102
- """Submits the exact ground-truth natural_language β†’ score β‰₯ 0.70."""
103
  r = env.reset(seed=seed)
104
  obs = r.observation
105
  fn_name = obs.extra["target_function"]
106
  contract = obs.contract_name
107
-
108
  contracts = load_contracts()
109
  gt_text = ""
110
  for c in contracts:
@@ -113,24 +100,15 @@ def oracle_t2(env: Task2Environment, seed: int, verbose: bool = False) -> Dict[s
113
  if fn and fn.get("property"):
114
  gt_text = fn["property"]["natural_language"]
115
  break
116
-
117
  if verbose:
118
  print(f" {contract}.{fn_name}()")
119
-
120
- # read code first (realistic browsing step)
121
  env.step(Action(action_type=ActionType.GET_FUNCTION_CODE))
122
  result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
123
  params={"property": gt_text}))
124
-
125
  r_val = result.reward.value
126
  score = round(r_val / 5.0, 4) if r_val > 0 else 0.0
127
- return {
128
- "seed": seed,
129
- "contract": contract,
130
- "function": fn_name,
131
- "grader_score": score,
132
- "cumulative_reward": result.observation.cumulative_reward,
133
- }
134
 
135
 
136
  def partial_t2(env: Task2Environment, seed: int) -> Dict[str, Any]:
@@ -148,16 +126,67 @@ def partial_t2(env: Task2Environment, seed: int) -> Dict[str, Any]:
148
  result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
149
  params={"property": comment}))
150
  r_val = result.reward.value
151
- score = round(r_val / 5.0, 4) if r_val > 0 else 0.0
152
- return {"seed": seed, "grader_score": score,
153
  "cumulative_reward": result.observation.cumulative_reward}
154
 
155
 
156
  def empty_t2(env: Task2Environment, seed: int) -> Dict[str, Any]:
157
  """Submits empty string β†’ score = 0.0."""
158
  env.reset(seed=seed)
159
- result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
160
- params={"property": ""}))
161
  return {"seed": seed, "grader_score": 0.0,
162
  "cumulative_reward": result.observation.cumulative_reward}
163
 
@@ -166,112 +195,144 @@ def empty_t2(env: Task2Environment, seed: int) -> Dict[str, Any]:
166
  # Evaluation runners
167
  # ─────────────────────────────────────────────────────────────────────────────
168
 
169
- def run_task1_eval(num_episodes: int, seed_offset: int, verbose: bool) -> Dict[str, Any]:
170
  print("\n" + "=" * 64)
171
  print("TASK 1 β€” Targeted Vulnerability Detection")
172
  print("=" * 64)
173
  contracts = load_contracts()
174
- entries = get_all_vulnerable_entries(contracts)
175
- print(f" Dataset: {len(contracts)} contracts, {len(entries)} vulnerable functions\n")
176
-
177
  env = Task1Environment()
178
 
179
- print("β–Ά Oracle agent (always submits correct answer):")
180
  oracle_eps = []
181
- for i in range(num_episodes):
182
- ep = oracle_t1(env, seed_offset + i, verbose=verbose)
183
  oracle_eps.append(ep)
184
  print(f" seed={ep['seed']:3d} {ep['contract']:12s}.{ep['target_function']:18s}"
185
  f" score={ep['grader_score']:.1f} reward={ep['cumulative_reward']:+.2f}")
186
- oracle_avg = sum(e["grader_score"] for e in oracle_eps) / num_episodes
187
- oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_eps) / num_episodes
188
- print(f"\n Oracle avg score : {oracle_avg:.3f} avg reward: {oracle_avg_r:+.2f}")
189
 
190
- print("\nβ–Ά Partial agent (right function, wrong vuln type β†’ 0.5):")
191
- partial_eps = [partial_t1(env, seed_offset + i) for i in range(num_episodes)]
192
- partial_avg = sum(e["grader_score"] for e in partial_eps) / num_episodes
193
- print(f" Partial avg score: {partial_avg:.3f}")
194
 
195
- print("\nβ–Ά Random agent (always wrong β†’ 0.0):")
196
- random_eps = [random_t1(env, seed_offset + i) for i in range(num_episodes)]
197
- random_avg = sum(e["grader_score"] for e in random_eps) / num_episodes
198
- print(f" Random avg score : {random_avg:.3f}")
199
 
200
  vuln_seen: Dict[str, int] = {}
201
  for ep in oracle_eps:
202
  v = ep.get("vulnerability", "unknown")
203
  vuln_seen[v] = vuln_seen.get(v, 0) + 1
204
- print("\nβ–Ά Vulnerability type coverage:")
205
  for v in sorted(vuln_seen):
206
  print(f" {vuln_seen[v]:2d}Γ— {v}")
207
 
208
- assert oracle_avg == 1.0, f"Oracle should be 1.0, got {oracle_avg}"
209
- assert partial_avg == 0.5, f"Partial should be 0.5, got {partial_avg}"
210
- assert random_avg == 0.0, f"Random should be 0.0, got {random_avg}"
211
- print("\n βœ… Task 1 score ordering: oracle(1.0) > partial(0.5) > random(0.0)")
212
 
213
  return {
214
  "task_id": "task1_vuln_detection",
215
  "oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_eps},
216
  "partial": {"avg_score": partial_avg, "episodes": partial_eps},
217
- "random": {"avg_score": random_avg, "episodes": random_eps},
218
  "vuln_coverage": vuln_seen,
219
  }
220
 
221
 
222
- def run_task2_eval(num_episodes: int, seed_offset: int, verbose: bool) -> Dict[str, Any]:
223
  print("\n" + "=" * 64)
224
  print("TASK 2 β€” Property Discovery")
225
  print("=" * 64)
226
- from data.data_loader import get_all_property_entries
227
  contracts = load_contracts()
228
- entries = get_all_property_entries(contracts)
229
- print(f" Dataset: {len(entries)} functions with properties\n")
230
-
231
  env = Task2Environment()
232
 
233
- print("β–Ά Oracle agent (submits ground-truth natural language):")
234
  oracle_eps = []
235
- for i in range(num_episodes):
236
- ep = oracle_t2(env, seed_offset + i, verbose=verbose)
237
  oracle_eps.append(ep)
238
  icon = "βœ…" if ep["grader_score"] >= 0.65 else "⚠️ "
239
  print(f" {icon} seed={ep['seed']:3d} {ep['contract']:12s}.{ep['function']:18s}"
240
  f" score={ep['grader_score']:.3f} reward={ep['cumulative_reward']:+.2f}")
241
- oracle_avg = sum(e["grader_score"] for e in oracle_eps) / num_episodes
242
- oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_eps) / num_episodes
243
- print(f"\n Oracle avg score : {oracle_avg:.3f} avg reward: {oracle_avg_r:+.2f}")
244
-
245
- print("\nβ–Ά Partial agent (submits NatSpec comment β€” partial signal):")
246
- partial_eps = [partial_t2(env, seed_offset + i) for i in range(num_episodes)]
247
- partial_avg = sum(e["grader_score"] for e in partial_eps) / num_episodes
248
- partial_avg_r = sum(e["cumulative_reward"] for e in partial_eps) / num_episodes
249
- print(f" Partial avg score: {partial_avg:.3f} avg reward: {partial_avg_r:+.2f}")
250
-
251
- print("\nβ–Ά Empty agent (submits nothing β†’ 0.0):")
252
- empty_eps = [empty_t2(env, seed_offset + i) for i in range(num_episodes)]
253
- empty_avg = sum(e["grader_score"] for e in empty_eps) / num_episodes
254
- print(f" Empty avg score : {empty_avg:.3f}")
255
-
256
- fn_seen: Dict[str, int] = {}
257
- for ep in oracle_eps:
258
- fn_seen[ep["function"]] = fn_seen.get(ep["function"], 0) + 1
259
- print("\nβ–Ά Function coverage:")
260
- for fn in sorted(fn_seen):
261
- print(f" {fn_seen[fn]:2d}Γ— {fn}")
262
-
263
- assert oracle_avg > 0.60, f"Oracle avg {oracle_avg:.3f} should be > 0.60"
264
- assert oracle_avg > partial_avg, "Oracle should beat partial"
265
- assert partial_avg >= empty_avg, "Partial should be >= empty"
266
- assert empty_avg == 0.0, f"Empty should be 0.0, got {empty_avg}"
267
- print(f"\n βœ… Task 2 score ordering: oracle({oracle_avg:.3f}) > partial({partial_avg:.3f}) > empty(0.0)")
268
 
269
  return {
270
  "task_id": "task2_property_discovery",
271
  "oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_eps},
272
  "partial": {"avg_score": partial_avg, "avg_reward": partial_avg_r, "episodes": partial_eps},
273
  "empty": {"avg_score": empty_avg, "episodes": empty_eps},
274
- "fn_coverage": fn_seen,
275
  }
276
 
277
 
@@ -281,45 +342,37 @@ def run_task2_eval(num_episodes: int, seed_offset: int, verbose: bool) -> Dict[s
281
 
282
  def main():
283
  parser = argparse.ArgumentParser(
284
- description="Evaluate Task 1 and/or Task 2 of the SC Audit RL Environment"
285
  )
286
- parser.add_argument("--episodes", type=int, default=8,
287
- help="Episodes per agent tier (default: 8)")
288
- parser.add_argument("--seed", type=int, default=42,
289
- help="Starting RNG seed (default: 42)")
290
- parser.add_argument("--task", choices=["1", "2", "all"], default="all",
291
- help="Which task(s) to evaluate (default: all)")
292
- parser.add_argument("--verbose", action="store_true",
293
- help="Print per-episode target details")
294
- parser.add_argument("--out", default="eval_results.json",
295
- help="Output file (default: eval_results.json)")
296
  args = parser.parse_args()
297
 
298
- report: Dict[str, Any] = {
299
- "num_episodes": args.episodes,
300
- "seed_offset": args.seed,
301
- }
302
 
303
  if args.task in ("1", "all"):
304
  report["task1"] = run_task1_eval(args.episodes, args.seed, args.verbose)
305
-
306
  if args.task in ("2", "all"):
307
  report["task2"] = run_task2_eval(args.episodes, args.seed, args.verbose)
 
 
308
 
309
- # ── Summary ──────────────────────────────────────────────────────────────
310
  print("\n" + "=" * 64)
311
  print("EVALUATION COMPLETE")
312
  print("=" * 64)
313
- if "task1" in report:
314
- t1 = report["task1"]
315
- print(f" Task 1 oracle={t1['oracle']['avg_score']:.3f} "
316
- f"partial={t1['partial']['avg_score']:.3f} "
317
- f"random={t1['random']['avg_score']:.3f}")
318
- if "task2" in report:
319
- t2 = report["task2"]
320
- print(f" Task 2 oracle={t2['oracle']['avg_score']:.3f} "
321
- f"partial={t2['partial']['avg_score']:.3f} "
322
- f"empty={t2['empty']['avg_score']:.3f}")
323
 
324
  with open(args.out, "w") as f:
325
  json.dump(report, f, indent=2)
 
1
  """
2
  eval.py
3
  -------
4
+ Evaluation harness for all three tasks.
5
 
6
+ Runs oracle / partial / baseline agents, verifying score orderings and
7
+ that reward shaping is meaningful across the trajectory.
 
8
 
9
  Usage:
10
+ python eval.py # all tasks, 8 episodes each
11
+ python eval.py --task 1|2|3 # single task
12
+ python eval.py --episodes 16 --verbose
13
+ python eval.py --out results.json
 
 
14
  """
15
 
16
  import argparse
17
  import json
 
18
  from typing import Any, Dict, List
19
 
20
  from tasks.task1.environment import Task1Environment
21
  from tasks.task2.environment import Task2Environment
22
+ from tasks.task3.environment import Task3Environment
23
  from env.schemas import Action, ActionType
24
  from data.data_loader import (
25
  load_contracts,
26
  get_function_by_name,
27
  get_all_vulnerable_entries,
28
+ get_all_property_entries,
29
+ get_all_task3_entries,
30
  )
31
 
32
 
 
35
  # ─────────────────────────────────────────────────────────────────────────────
36
 
37
  def oracle_t1(env: Task1Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
38
+ """Submits the exact ground-truth function + vulnerability β†’ score = 1.0."""
39
  r = env.reset(seed=seed)
40
  obs = r.observation
41
+ fn_name = env.state().target_function
 
 
42
  contracts = load_contracts()
43
  vuln_issue = ""
44
  for c in contracts:
 
46
  if fn and fn.get("vulnerable"):
47
  vuln_issue = fn["vulnerability_details"]["issue"]
48
  break
 
49
  if verbose:
50
  print(f" {obs.contract_name}.{fn_name}() [{vuln_issue}]")
 
51
  env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
52
  env.step(Action(action_type=ActionType.GET_FUNCTION_CODE,
53
  params={"function_name": fn_name}))
54
  result = env.step(Action(action_type=ActionType.SUBMIT,
55
  params={"function_name": fn_name,
56
  "vulnerability_type": vuln_issue}))
 
57
  v = result.reward.value
58
  score = 1.0 if v >= 4.9 else (0.5 if v >= 0.9 else 0.0)
59
+ return {"seed": seed, "contract": obs.contract_name, "target_function": fn_name,
60
+ "vulnerability": vuln_issue, "grader_score": score,
61
+ "cumulative_reward": result.observation.cumulative_reward}
62
 
63
 
64
  def partial_t1(env: Task1Environment, seed: int) -> Dict[str, Any]:
65
+ """Right function, 'unknown' vuln type β†’ score = 0.5."""
66
  env.reset(seed=seed)
67
  fn_name = env.state().target_function
68
  result = env.step(Action(action_type=ActionType.SUBMIT,
69
+ params={"function_name": fn_name, "vulnerability_type": "unknown"}))
 
70
  v = result.reward.value
71
  return {"seed": seed, "grader_score": 0.5 if v >= 0.9 else 0.0,
72
  "cumulative_reward": result.observation.cumulative_reward}
73
 
74
 
75
+ def wrong_t1(env: Task1Environment, seed: int) -> Dict[str, Any]:
76
  """Always submits 'constructor' β†’ score = 0.0."""
77
  env.reset(seed=seed)
78
  result = env.step(Action(action_type=ActionType.SUBMIT,
 
87
  # ─────────────────────────────────────────────────────────────────────────────
88
 
89
  def oracle_t2(env: Task2Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
90
+ """Submits ground-truth natural_language β†’ score β‰₯ 0.70."""
91
  r = env.reset(seed=seed)
92
  obs = r.observation
93
  fn_name = obs.extra["target_function"]
94
  contract = obs.contract_name
 
95
  contracts = load_contracts()
96
  gt_text = ""
97
  for c in contracts:
 
100
  if fn and fn.get("property"):
101
  gt_text = fn["property"]["natural_language"]
102
  break
 
103
  if verbose:
104
  print(f" {contract}.{fn_name}()")
 
 
105
  env.step(Action(action_type=ActionType.GET_FUNCTION_CODE))
106
  result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
107
  params={"property": gt_text}))
 
108
  r_val = result.reward.value
109
  score = round(r_val / 5.0, 4) if r_val > 0 else 0.0
110
+ return {"seed": seed, "contract": contract, "function": fn_name,
111
+ "grader_score": score, "cumulative_reward": result.observation.cumulative_reward}
112
 
113
 
114
  def partial_t2(env: Task2Environment, seed: int) -> Dict[str, Any]:
 
126
  result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY,
127
  params={"property": comment}))
128
  r_val = result.reward.value
129
+ return {"seed": seed, "grader_score": round(r_val / 5.0, 4) if r_val > 0 else 0.0,
 
130
  "cumulative_reward": result.observation.cumulative_reward}
131
 
132
 
133
  def empty_t2(env: Task2Environment, seed: int) -> Dict[str, Any]:
134
  """Submits empty string β†’ score = 0.0."""
135
  env.reset(seed=seed)
136
+ result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY, params={"property": ""}))
137
+ return {"seed": seed, "grader_score": 0.0,
138
+ "cumulative_reward": result.observation.cumulative_reward}
139
+
140
+
141
+ # ─────────────────────────────────────────────────────────────────────────────
142
+ # Task 3 agents
143
+ # ─────────────────────────────────────────────────────────────────────────────
144
+
145
+ def oracle_t3(env: Task3Environment, seed: int, verbose: bool = False) -> Dict[str, Any]:
146
+ """Always submits the exact target function β†’ score = 1.0."""
147
+ r = env.reset(seed=seed)
148
+ obs = r.observation
149
+ fn_name = env.state().target_function
150
+ contract = obs.contract_name
151
+ if verbose:
152
+ prop = obs.extra.get("property_english", "")[:60]
153
+ print(f" {contract}.{fn_name}() \"{prop}\"")
154
+ env.step(Action(action_type=ActionType.GET_FORMALIZED_PROPERTY))
155
+ env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
156
+ result = env.step(Action(action_type=ActionType.SUBMIT_FUNCTION,
157
+ params={"function_name": fn_name}))
158
+ v = result.reward.value
159
+ score = 1.0 if v >= 4.9 else (0.3 if v >= 1.0 else 0.0)
160
+ return {"seed": seed, "contract": contract, "target_function": fn_name,
161
+ "grader_score": score, "cumulative_reward": result.observation.cumulative_reward}
162
+
163
+
164
+ def subfunction_t3(env: Task3Environment, seed: int) -> Dict[str, Any]:
165
+ """Submits the first partial-credit subfunction if it exists, else 'constructor'."""
166
+ r = env.reset(seed=seed)
167
+ obs = r.observation
168
+ contracts = load_contracts()
169
+ partial_fns = []
170
+ for c in contracts:
171
+ if c["contract_name"] == obs.contract_name:
172
+ fn = get_function_by_name(c, env.state().target_function)
173
+ if fn:
174
+ partial_fns = fn.get("task3", {}).get("partial_credit_functions", [])
175
+ break
176
+ submit_name = partial_fns[0] if partial_fns else "constructor"
177
+ result = env.step(Action(action_type=ActionType.SUBMIT_FUNCTION,
178
+ params={"function_name": submit_name}))
179
+ v = result.reward.value
180
+ score = 1.0 if v >= 4.9 else (0.3 if v >= 1.0 else 0.0)
181
+ return {"seed": seed, "grader_score": score, "submitted": submit_name,
182
+ "cumulative_reward": result.observation.cumulative_reward}
183
+
184
+
185
+ def wrong_t3(env: Task3Environment, seed: int) -> Dict[str, Any]:
186
+ """Always submits 'constructor' β†’ score = 0.0."""
187
+ env.reset(seed=seed)
188
+ result = env.step(Action(action_type=ActionType.SUBMIT_FUNCTION,
189
+ params={"function_name": "constructor"}))
190
  return {"seed": seed, "grader_score": 0.0,
191
  "cumulative_reward": result.observation.cumulative_reward}
192
 
 
195
  # Evaluation runners
196
  # ─────────────────────────────────────────────────────────────────────────────
197
 
198
+ def run_task1_eval(n: int, seed_offset: int, verbose: bool) -> Dict[str, Any]:
199
  print("\n" + "=" * 64)
200
  print("TASK 1 β€” Targeted Vulnerability Detection")
201
  print("=" * 64)
202
  contracts = load_contracts()
203
+ print(f" Dataset: {len(contracts)} contracts, "
204
+ f"{len(get_all_vulnerable_entries(contracts))} vulnerable functions\n")
 
205
  env = Task1Environment()
206
 
207
+ print("β–Ά Oracle (correct function + correct vuln type β†’ 1.0):")
208
  oracle_eps = []
209
+ for i in range(n):
210
+ ep = oracle_t1(env, seed_offset + i, verbose)
211
  oracle_eps.append(ep)
212
  print(f" seed={ep['seed']:3d} {ep['contract']:12s}.{ep['target_function']:18s}"
213
  f" score={ep['grader_score']:.1f} reward={ep['cumulative_reward']:+.2f}")
214
+ oracle_avg = sum(e["grader_score"] for e in oracle_eps) / n
215
+ oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_eps) / n
216
+ print(f"\n Oracle avg: {oracle_avg:.3f} reward: {oracle_avg_r:+.2f}")
217
 
218
+ print("\nβ–Ά Partial (right function, wrong vuln β†’ 0.5):")
219
+ partial_eps = [partial_t1(env, seed_offset + i) for i in range(n)]
220
+ partial_avg = sum(e["grader_score"] for e in partial_eps) / n
221
+ print(f" Partial avg: {partial_avg:.3f}")
222
 
223
+ print("\nβ–Ά Wrong (always 'constructor' β†’ 0.0):")
224
+ wrong_eps = [wrong_t1(env, seed_offset + i) for i in range(n)]
225
+ wrong_avg = sum(e["grader_score"] for e in wrong_eps) / n
226
+ print(f" Wrong avg: {wrong_avg:.3f}")
227
 
228
  vuln_seen: Dict[str, int] = {}
229
  for ep in oracle_eps:
230
  v = ep.get("vulnerability", "unknown")
231
  vuln_seen[v] = vuln_seen.get(v, 0) + 1
232
+ print("\nβ–Ά Vulnerability coverage:")
233
  for v in sorted(vuln_seen):
234
  print(f" {vuln_seen[v]:2d}Γ— {v}")
235
 
236
+ assert oracle_avg == 1.0
237
+ assert partial_avg == 0.5
238
+ assert wrong_avg == 0.0
239
+ print("\n βœ… Task 1: oracle(1.0) > partial(0.5) > wrong(0.0)")
240
 
241
  return {
242
  "task_id": "task1_vuln_detection",
243
  "oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_eps},
244
  "partial": {"avg_score": partial_avg, "episodes": partial_eps},
245
+ "wrong": {"avg_score": wrong_avg, "episodes": wrong_eps},
246
  "vuln_coverage": vuln_seen,
247
  }
248
 
249
 
250
+ def run_task2_eval(n: int, seed_offset: int, verbose: bool) -> Dict[str, Any]:
251
  print("\n" + "=" * 64)
252
  print("TASK 2 β€” Property Discovery")
253
  print("=" * 64)
 
254
  contracts = load_contracts()
255
+ print(f" Dataset: {len(get_all_property_entries(contracts))} property entries\n")
 
 
256
  env = Task2Environment()
257
 
258
+ print("β–Ά Oracle (submits ground-truth natural language):")
259
  oracle_eps = []
260
+ for i in range(n):
261
+ ep = oracle_t2(env, seed_offset + i, verbose)
262
  oracle_eps.append(ep)
263
  icon = "βœ…" if ep["grader_score"] >= 0.65 else "⚠️ "
264
  print(f" {icon} seed={ep['seed']:3d} {ep['contract']:12s}.{ep['function']:18s}"
265
  f" score={ep['grader_score']:.3f} reward={ep['cumulative_reward']:+.2f}")
266
+ oracle_avg = sum(e["grader_score"] for e in oracle_eps) / n
267
+ oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_eps) / n
268
+ print(f"\n Oracle avg: {oracle_avg:.3f} reward: {oracle_avg_r:+.2f}")
269
+
270
+ print("\nβ–Ά Partial (submits NatSpec comment):")
271
+ partial_eps = [partial_t2(env, seed_offset + i) for i in range(n)]
272
+ partial_avg = sum(e["grader_score"] for e in partial_eps) / n
273
+ partial_avg_r = sum(e["cumulative_reward"] for e in partial_eps) / n
274
+ print(f" Partial avg: {partial_avg:.3f} reward: {partial_avg_r:+.2f}")
275
+
276
+ print("\nβ–Ά Empty (submits nothing β†’ 0.0):")
277
+ empty_eps = [empty_t2(env, seed_offset + i) for i in range(n)]
278
+ empty_avg = sum(e["grader_score"] for e in empty_eps) / n
279
+ print(f" Empty avg: {empty_avg:.3f}")
280
+
281
+ assert oracle_avg > 0.60
282
+ assert oracle_avg > partial_avg
283
+ assert empty_avg == 0.0
284
+ print(f"\n βœ… Task 2: oracle({oracle_avg:.3f}) > partial({partial_avg:.3f}) > empty(0.0)")
285
 
286
  return {
287
  "task_id": "task2_property_discovery",
288
  "oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_eps},
289
  "partial": {"avg_score": partial_avg, "avg_reward": partial_avg_r, "episodes": partial_eps},
290
  "empty": {"avg_score": empty_avg, "episodes": empty_eps},
291
+ }
292
+
293
+
294
+ def run_task3_eval(n: int, seed_offset: int, verbose: bool) -> Dict[str, Any]:
295
+ print("\n" + "=" * 64)
296
+ print("TASK 3 β€” Rule Checker")
297
+ print("=" * 64)
298
+ contracts = load_contracts()
299
+ print(f" Dataset: {len(get_all_task3_entries(contracts))} rule-check episodes\n")
300
+ env = Task3Environment()
301
+
302
+ print("β–Ά Oracle (submits exact target function β†’ 1.0):")
303
+ oracle_eps = []
304
+ for i in range(n):
305
+ ep = oracle_t3(env, seed_offset + i, verbose)
306
+ oracle_eps.append(ep)
307
+ print(f" seed={ep['seed']:3d} {ep['contract']:12s}.{ep['target_function']:18s}"
308
+ f" score={ep['grader_score']:.1f} reward={ep['cumulative_reward']:+.2f}")
309
+ oracle_avg = sum(e["grader_score"] for e in oracle_eps) / n
310
+ oracle_avg_r = sum(e["cumulative_reward"] for e in oracle_eps) / n
311
+ print(f"\n Oracle avg: {oracle_avg:.3f} reward: {oracle_avg_r:+.2f}")
312
+
313
+ print("\nβ–Ά Subfunction (partial-credit callee or fallback to wrong):")
314
+ sub_eps = [subfunction_t3(env, seed_offset + i) for i in range(n)]
315
+ sub_avg = sum(e["grader_score"] for e in sub_eps) / n
316
+ sub_avg_r = sum(e["cumulative_reward"] for e in sub_eps) / n
317
+ submitted = list({e.get("submitted", "?") for e in sub_eps})
318
+ print(f" Subfunction avg: {sub_avg:.3f} reward: {sub_avg_r:+.2f} "
319
+ f"submitted fns: {submitted}")
320
+
321
+ print("\nβ–Ά Wrong (always 'constructor' β†’ 0.0):")
322
+ wrong_eps = [wrong_t3(env, seed_offset + i) for i in range(n)]
323
+ wrong_avg = sum(e["grader_score"] for e in wrong_eps) / n
324
+ print(f" Wrong avg: {wrong_avg:.3f}")
325
+
326
+ assert oracle_avg == 1.0
327
+ assert 0.0 <= sub_avg <= oracle_avg
328
+ assert wrong_avg == 0.0
329
+ print(f"\n βœ… Task 3: oracle(1.0) β‰₯ subfunction({sub_avg:.3f}) > wrong(0.0)")
330
+
331
+ return {
332
+ "task_id": "task3_rule_checker",
333
+ "oracle": {"avg_score": oracle_avg, "avg_reward": oracle_avg_r, "episodes": oracle_eps},
334
+ "subfunction": {"avg_score": sub_avg, "avg_reward": sub_avg_r, "episodes": sub_eps},
335
+ "wrong": {"avg_score": wrong_avg, "episodes": wrong_eps},
336
  }
337
 
338
 
 
342
 
343
  def main():
344
  parser = argparse.ArgumentParser(
345
+ description="Evaluate Task 1, 2, and/or 3 of the SC Audit RL Environment"
346
  )
347
+ parser.add_argument("--episodes", type=int, default=8)
348
+ parser.add_argument("--seed", type=int, default=42)
349
+ parser.add_argument("--task", choices=["1", "2", "3", "all"], default="all")
350
+ parser.add_argument("--verbose", action="store_true")
351
+ parser.add_argument("--out", default="eval_results.json")
 
352
  args = parser.parse_args()
353
 
354
+ report: Dict[str, Any] = {"num_episodes": args.episodes, "seed_offset": args.seed}
355
 
356
  if args.task in ("1", "all"):
357
  report["task1"] = run_task1_eval(args.episodes, args.seed, args.verbose)
 
358
  if args.task in ("2", "all"):
359
  report["task2"] = run_task2_eval(args.episodes, args.seed, args.verbose)
360
+ if args.task in ("3", "all"):
361
+ report["task3"] = run_task3_eval(args.episodes, args.seed, args.verbose)
362
 
 
363
  print("\n" + "=" * 64)
364
  print("EVALUATION COMPLETE")
365
  print("=" * 64)
366
+ for label, key, tiers in [
367
+ ("Task 1", "task1", ["oracle", "partial", "wrong"]),
368
+ ("Task 2", "task2", ["oracle", "partial", "empty"]),
369
+ ("Task 3", "task3", ["oracle", "subfunction", "wrong"]),
370
+ ]:
371
+ if key in report:
372
+ scores = " ".join(
373
+ f"{t}={report[key][t]['avg_score']:.3f}" for t in tiers
374
+ )
375
+ print(f" {label} {scores}")
376
 
377
  with open(args.out, "w") as f:
378
  json.dump(report, f, indent=2)
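The Task 3 runner asserts a strict score ordering across agent tiers before writing the report. A self-contained sketch of that ordering check with dummy episode records shaped like the harness output (values are illustrative, not real results):

```python
def avg_score(episodes):
    """Mean grader score over one agent tier, as in the eval runners."""
    return sum(e["grader_score"] for e in episodes) / len(episodes)

# Illustrative episode records shaped like the harness output (dummy values).
oracle = [{"grader_score": 1.0}, {"grader_score": 1.0}]
subfn = [{"grader_score": 0.3}, {"grader_score": 0.0}]
wrong = [{"grader_score": 0.0}, {"grader_score": 0.0}]

o, s, w = avg_score(oracle), avg_score(subfn), avg_score(wrong)
# The Task 3 ordering the harness asserts: oracle >= subfunction > wrong
ordering_holds = (o == 1.0) and (0.0 <= s <= o) and (w == 0.0)
```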
inference.py CHANGED
@@ -2,8 +2,8 @@
2
  inference.py
3
  ------------
4
  Baseline inference script for the Smart Contract Audit RL Environment.
5
- Implements Task 1 (Vulnerability Detection) and Task 2 (Property Discovery).
6
- Task 3 is a placeholder that returns 0.0.
7
 
8
  Environment variables:
9
  API_BASE_URL – LLM API endpoint (e.g. https://api.openai.com/v1)
@@ -30,6 +30,7 @@ from openai import OpenAI
30
 
31
  from tasks.task1.environment import Task1Environment
32
  from tasks.task2.environment import Task2Environment
 
33
  from env.schemas import Action, ActionType
34
 
35
  # ─────────────────────────────────────────────────────────────────────────────
@@ -261,14 +262,103 @@ def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
261
  "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r}
262
 
263
 
264
- def run_task3_placeholder() -> Dict[str, Any]:
265
  print("\n" + "="*60)
266
- print("TASK 3: Rule Checker [PLACEHOLDER β€” not implemented]")
267
  print("="*60)
268
- print(" Skipping. Score: 0.0")
269
  return {"task_id": "task3_rule_checker", "name": "Rule Checker",
270
- "status": "placeholder", "num_episodes": 0, "episodes": [],
271
- "avg_grader_score": 0.0, "avg_cumulative_reward": 0.0}
272
 
273
 
274
  # ─────────────────────────────────────────────────────────────────────────────
@@ -281,15 +371,15 @@ def main():
281
 
282
  t1 = run_task1(NUM_EPISODES)
283
  t2 = run_task2(NUM_EPISODES)
284
- t3 = run_task3_placeholder()
285
 
286
  results = {
287
  "model": MODEL_NAME, "base_url": API_BASE_URL,
288
  "tasks": [t1, t2, t3],
289
  }
290
 
291
- active = [t for t in results["tasks"] if t["status"] == "active"]
292
- overall = sum(t["avg_grader_score"] for t in active) / len(active) if active else 0.0
293
  results["overall_avg_score"] = overall
294
 
295
  print("\n" + "="*60)
 
2
  inference.py
3
  ------------
4
  Baseline inference script for the Smart Contract Audit RL Environment.
5
+ Implements Task 1 (Vulnerability Detection), Task 2 (Property Discovery),
6
+ and Task 3 (Rule Checker).
7
 
8
  Environment variables:
9
  API_BASE_URL – LLM API endpoint (e.g. https://api.openai.com/v1)
 
30
 
31
  from tasks.task1.environment import Task1Environment
32
  from tasks.task2.environment import Task2Environment
33
+ from tasks.task3.environment import Task3Environment
34
  from env.schemas import Action, ActionType
35
 
36
  # ─────────────────────────────────────────────────────────────────────────────
 
262
  "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r}
263
 
264
 
265
+ T3_SYSTEM = """You are a smart contract security auditor checking rule compliance.
266
+
267
+ You are given a Solidity contract and a property (rule) in natural English.
268
+ Your task is to find the ONE function that violates this property.
269
+
270
+ ## Actions (respond with JSON only, ONE action per turn):
271
+ {"action": "list_functions", "params": {}}
272
+ {"action": "get_formalized_property", "params": {}}
273
+ {"action": "get_function_metadata", "params": {"function_name": "<n>"}}
274
+ {"action": "get_function_code", "params": {"function_name": "<n>"}}
275
+ {"action": "get_state_variable", "params": {"variable_name": "<n>"}}
276
+ {"action": "get_call_graph", "params": {}}
277
+ {"action": "submit_function", "params": {"function_name": "<n>"}}
278
+
279
+ ## Strategy:
280
+ 1. Read the property shown as property_english in the observation.
281
+ 2. list_functions to survey candidates.
282
+ 3. get_formalized_property for the precise pre/post-condition (cheap: -0.03).
283
+ 4. get_function_code on the 1-2 most suspicious functions.
284
+ 5. submit_function when confident β€” ONE attempt only.
285
+
286
+ Clues: missing require, no access modifier, unchecked external call, unbounded array,
287
+ tx.origin auth, integer overflow, timestamp manipulation, reentrancy ordering.
288
+
289
+ Respond ONLY with valid JSON. No markdown, no explanation."""
290
+
291
+
292
+ def _t3_user_msg(obs: Dict[str, Any]) -> str:
293
+ extra = obs.get("extra", {})
294
+ return (
295
+ f"Contract : {obs['contract_name']}\n"
296
+ f"Property : {extra.get('property_english', '(no property)')}\n"
297
+ f"Step: {obs['step_count']} | Reward: {obs['cumulative_reward']:.2f}\n\n"
298
+ f"Last action: {obs['last_action'] or 'None'}\n"
299
+ f"Result:\n{obs['last_action_result'] or 'Episode started.'}"
300
+ )
301
+
302
+
303
+ def run_t3_episode(env: Task3Environment, seed: int, ep: int) -> Dict[str, Any]:
304
+ r = env.reset(seed=seed)
305
+ obs = r.observation.model_dump()
306
+ prop_preview = obs['extra'].get('property_english', '')[:55]
307
+ print(f" ep={ep} seed={seed} {obs['contract_name']} \"{prop_preview}...\"")
308
+
309
+ messages = [{"role": "system", "content": T3_SYSTEM}]
310
+ grader_score = 0.0
311
+ cum_reward = 0.0
312
+
313
+ for step in range(15):
314
+ messages.append({"role": "user", "content": _t3_user_msg(obs)})
315
+ try:
316
+ resp = client.chat.completions.create(
317
+ model=MODEL_NAME, messages=messages,
318
+ max_tokens=200, temperature=0.0,
319
+ )
320
+ raw = resp.choices[0].message.content.strip()
321
+ except Exception as e:
322
+ print(f" LLM error: {e}", file=sys.stderr)
323
+ break
324
+
325
+ try:
326
+ parsed = json.loads(raw)
327
+ at = ActionType(parsed["action"])
328
+ params = parsed.get("params", {})
329
+ except Exception:
330
+ at, params = ActionType.LIST_FUNCTIONS, {}
331
+
332
+ messages.append({"role": "assistant", "content": raw})
333
+ result = env.step(Action(action_type=at, params=params))
334
+ obs = result.observation.model_dump()
335
+ print(f" step {step+1:2d}: {at.value:28s} r={result.reward.value:+.2f}")
336
+
337
+ if result.done:
338
+ v = result.reward.value
339
+ grader_score = 1.0 if v >= 4.9 else (0.3 if v >= 1.0 else 0.0)
340
+ cum_reward = obs["cumulative_reward"]
341
+ break
342
+ time.sleep(0.3)
343
+
344
+ print(f" β†’ grader_score={grader_score:.1f} cum_reward={cum_reward:.2f}")
345
+ return {"episode": ep, "seed": seed, "contract": obs["contract_name"],
346
+ "grader_score": grader_score, "cumulative_reward": cum_reward}
347
+
348
+
349
+ def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
350
  print("\n" + "="*60)
351
+ print("TASK 3: Rule Checker")
352
  print("="*60)
353
+ env = Task3Environment()
354
+ episodes = [run_t3_episode(env, 42 + i, i + 1) for i in range(n)]
355
+ avg_s = sum(e["grader_score"] for e in episodes) / n
356
+ avg_r = sum(e["cumulative_reward"] for e in episodes) / n
357
+ print(f"\n Avg grader score : {avg_s:.3f}")
358
+ print(f" Avg cum reward : {avg_r:.2f}")
359
  return {"task_id": "task3_rule_checker", "name": "Rule Checker",
360
+ "status": "active", "num_episodes": n, "episodes": episodes,
361
+ "avg_grader_score": avg_s, "avg_cumulative_reward": avg_r}
362
 
363
 
364
  # ─────────────────────────────────────────────────────────────────────────────
 
371
 
372
  t1 = run_task1(NUM_EPISODES)
373
  t2 = run_task2(NUM_EPISODES)
374
+ t3 = run_task3(NUM_EPISODES)
375
 
376
  results = {
377
  "model": MODEL_NAME, "base_url": API_BASE_URL,
378
  "tasks": [t1, t2, t3],
379
  }
380
 
381
+ active = results["tasks"]
382
+ overall = sum(t["avg_grader_score"] for t in active) / len(active)
383
  results["overall_avg_score"] = overall
384
 
385
  print("\n" + "="*60)
openenv.yaml CHANGED
@@ -1,11 +1,10 @@
 name: smart-contract-audit-env
-version: "1.1.0"
+version: "1.2.0"
 description: >
   Reinforcement learning environment for smart contract security analysis.
   Agents interact with real-world Solidity contract data from Certora-audited
-  projects, learning to detect vulnerabilities and discover correctness
-  properties — tasks that professional auditors perform daily.
-
+  projects, practising three real audit tasks: vulnerability detection,
+  property discovery, and rule checking.
 author: "SmartAudit Team"
 license: MIT
 
@@ -37,10 +36,10 @@ tasks:
   - id: task3_rule_checker
     name: Rule Checker
     difficulty: easy
-    status: placeholder
+    status: active
     description: >
-      Given a natural-language property and a Solidity file, identify the
-      function that violates that property.
+      Given a natural-language property and a Solidity contract, identify the
+      function that violates that property. Partial credit for internal subfunctions.
     max_steps: 15
     reward_range: [-5.0, 5.0]
     grader: tasks/task3/grader.py
@@ -49,70 +48,63 @@
 observation_space:
   type: object
   properties:
-    task_id: {type: string, description: Active task identifier}
-    contract_name: {type: string, description: Solidity contract name}
-    contract_description: {type: string, description: Human-readable contract description}
-    available_actions: {type: array, items: {type: string}, description: Valid action types}
+    task_id: {type: string}
+    contract_name: {type: string}
+    contract_description: {type: string}
+    available_actions: {type: array, items: {type: string}}
     last_action: {type: string, nullable: true}
     last_action_result: {type: string, nullable: true}
     step_count: {type: integer}
     cumulative_reward: {type: number}
     done: {type: boolean}
-    extra: {type: object, description: Task-specific hints}
+    extra: {type: object}
 
 action_space:
   task1:
-    type: object
-    actions:
-      list_functions: {params: {}, reward: -0.05}
-      get_function_code: {params: {function_name: string}, reward: "+0.05 / -0.10"}
-      get_function_summary: {params: {function_name: string}, reward: "+0.03 / -0.05"}
-      get_file_metadata: {params: {}, reward: -0.04}
-      get_state_variable: {params: {variable_name: "string (opt)"}, reward: -0.05}
-      get_call_graph: {params: {}, reward: -0.08}
-      submit: {params: {function_name: str, vulnerability_type: str}, reward: "+5.0 / +1.0 / -1.5"}
+    list_functions: {params: {}, reward: -0.05}
+    get_function_code: {params: {function_name: string}, reward: "+0.05 / -0.10"}
+    get_function_summary: {params: {function_name: string}, reward: "+0.03 / -0.05"}
+    get_file_metadata: {params: {}, reward: -0.04}
+    get_state_variable: {params: {variable_name: "string opt"}, reward: -0.05}
+    get_call_graph: {params: {}, reward: -0.08}
+    submit: {params: {function_name: string, vulnerability_type: string}, reward: "+5.0 / +1.0 / -1.5"}
   task2:
-    type: object
-    actions:
-      get_function_code: {params: {}, reward: -0.06}
-      get_function_natspec: {params: {}, reward: -0.08}
-      get_file_natspec: {params: {}, reward: -0.03}
-      get_related_functions: {params: {}, reward: -0.06}
-      get_io: {params: {}, reward: -0.04}
-      get_similar_rule: {params: {}, reward: -0.20}
-      submit_property: {params: {property: string}, reward: "0.0–5.0 (keyword-weighted)"}
+    get_function_code: {params: {}, reward: -0.06}
+    get_function_natspec: {params: {}, reward: -0.08}
+    get_file_natspec: {params: {}, reward: -0.03}
+    get_related_functions: {params: {}, reward: -0.06}
+    get_io: {params: {}, reward: -0.04}
+    get_similar_rule: {params: {}, reward: -0.20}
+    submit_property: {params: {property: string}, reward: "0.0-5.0 keyword-weighted, one attempt"}
+  task3:
+    list_functions: {params: {}, reward: -0.05}
+    get_function_metadata: {params: {function_name: string}, reward: -0.05}
+    get_function_code: {params: {function_name: string}, reward: -0.10}
+    get_state_variable: {params: {variable_name: "string opt"}, reward: -0.05}
+    get_call_graph: {params: {}, reward: -0.08}
+    get_formalized_property: {params: {}, reward: -0.03}
+    submit_function: {params: {function_name: string}, reward: "+5.0 / +1.5 / -1.5, one attempt"}
 
 reward:
   type: shaped
-  description: >
-    Per-step costs encourage efficient exploration. Positive shaping rewards
-    fire when the agent inspects the actual target. Terminal rewards reflect
-    grader score accuracy.
+  all_tasks_shared:
+    repeated_query: -0.40
   task1_shaping:
-    list_functions: -0.05
     get_function_code_correct: +0.05
+    get_function_code_wrong: -0.10
     get_function_summary_correct: +0.03
-    get_function_summary_wrong: -0.05
-    get_file_metadata: -0.04
-    get_state_variable: -0.05
-    get_call_graph: -0.08
-    repeated_query: -0.40
+    get_function_summary_wrong: -0.05
   task1_terminal:
-    correct_submission: +5.0
-    partial_submission: +1.0
-    wrong_submission: -1.5
-  task2_shaping:
-    get_function_code: -0.06
-    get_function_natspec: -0.08
-    get_file_natspec: -0.03
-    get_related_functions: -0.06
-    get_io: -0.04
-    get_similar_rule: -0.20
-    repeated_query: -0.40
+    correct: +5.0
+    partial: +1.0
+    wrong: -1.5
   task2_terminal:
-    score_range: [0.0, 5.0]
-    formula: "score * 5.0 where score = 0.70*(key_matches/total_key) + 0.30*(bonus_matches/total_bonus)"
+    formula: "score * 5.0 where score = 0.70*(key_matches/key_total) + 0.30*(bonus_matches/bonus_total)"
+    range: [0.0, 5.0]
+  task3_terminal:
+    correct_function: +5.0
+    subfunction: +1.5
+    wrong_function: -1.5
 
@@ -120,6 +112,7 @@ data:
   source: "Certora audited DeFi projects"
   num_contracts: 4
   num_vulnerable_functions: 8
   num_property_functions: 11
+  num_task3_episodes: 8
   vulnerability_types:
     - Reentrancy
     - Missing access control
@@ -132,14 +125,14 @@ data:
 
 interface:
   http:
-    reset: POST /reset
-    step: POST /step
-    state: GET /state
-    tasks: GET /tasks
-    health: GET /health
-    action_space: GET /action_space?task_id=<id>
-    observation_space: GET /observation_space
+    reset: "POST /reset"
+    step: "POST /step"
+    state: "GET /state"
+    tasks: "GET /tasks"
+    health: "GET /health"
+    action_space: "GET /action_space?task_id=<id>"
+    observation_space: "GET /observation_space"
   python:
-    reset: env.reset(seed=None) -> ResetResult
-    step: env.step(action) -> StepResult
-    state: env.state() -> StateResult
+    reset: "env.reset(seed=None) -> ResetResult"
+    step: "env.step(action) -> StepResult"
+    state: "env.state() -> StateResult"
tasks/task3/__init__.py CHANGED
@@ -1,31 +1,5 @@
-"""
-tasks/task3/__init__.py
------------------------
-Task 3: Rule Checker (PLACEHOLDER)
-
-TODO: Implement this task.
-
-Episode setup:
-  - One Solidity file with at least one function breaking a given property
-  - Agent is shown the property in natural English
-
-Actions (to implement):
-  - get_formalized_property : -0.03
-  - list_functions          : -0.05
-  - get_function_metadata   : -0.05
-  - get_function_code       : -0.10
-  - get_state_variables     : -0.05
-  - get_call_graph          : -0.08
-  - submit_function         :
-      - correct = +5.0
-      - subfunction of target = +1.5
-      - wrong = -1.5
-    (ONE submission per episode)
-
-See README.md for full task specification.
-"""
-
-# TODO: Task 3 – Rule Checker
-# from tasks.task3.environment import Task3Environment
-
-__all__: list = []
+# Task 3: Rule Checker
+from tasks.task3.environment import Task3Environment
+from tasks.task3.grader import Task3Grader
+
+__all__ = ["Task3Environment", "Task3Grader"]
tasks/task3/environment.py ADDED
@@ -0,0 +1,350 @@
+"""
+environment.py (Task 3 – Rule Checker)
+--------------------------------------
+OpenEnv-compliant RL environment.
+
+Episode setup
+─────────────
+- A Solidity contract is selected that contains at least one function
+  violating a known property.
+- The agent sees: contract description + the property in natural English.
+- The agent must identify which function breaks that property.
+
+Observation at reset
+────────────────────
+extra.property_english – the violated property in plain English
+extra.hint             – instructions for the agent
+
+Actions & rewards
+─────────────────
+list_functions           -0.05  see all function names
+get_function_metadata    -0.05  signature / visibility / modifiers / params
+get_function_code        -0.10  full Solidity source of any function
+get_state_variables      -0.05  list or inspect state variables
+get_call_graph           -0.08  function call graph
+get_formalized_property  -0.03  formal pre/post-condition version of property
+submit_function          terminal: +5.0 / +1.5 / -1.5 (ONE attempt)
+repeated_query           -0.40
+
+Difficulty: Easy
+The property text directly names the invariant broken; reading 2-3 functions
+should let most agents identify the culprit quickly.
+"""
+
+from __future__ import annotations
+
+import random
+from typing import Any, Dict, List, Optional, Set
+
+from data.data_loader import (
+    load_contracts,
+    sample_task3_episode,
+    get_function_by_name,
+    get_state_variable_by_name,
+    list_function_names,
+    list_state_variable_names,
+)
+from env.base_env import BaseEnv
+from env.schemas import (
+    Action,
+    ActionType,
+    Observation,
+    Reward,
+    ResetResult,
+    StateResult,
+    StepResult,
+)
+from tasks.task3.grader import Task3Grader
+
+TASK_ID = "task3_rule_checker"
+MAX_STEPS = 15
+
+AVAILABLE_ACTIONS = [
+    ActionType.LIST_FUNCTIONS,
+    ActionType.GET_FUNCTION_METADATA,
+    ActionType.GET_FUNCTION_CODE,
+    ActionType.GET_STATE_VARIABLE,
+    ActionType.GET_CALL_GRAPH,
+    ActionType.GET_FORMALIZED_PROPERTY,
+    ActionType.SUBMIT_FUNCTION,
+]
+
+
+class Task3Environment(BaseEnv):
+    """Task 3: Rule Checker — identify the function that violates a given property."""
+
+    def __init__(self, contracts_path: Optional[str] = None) -> None:
+        self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
+        self._rng = random.Random()
+
+        # Episode state — initialised by reset()
+        self._contract: Dict[str, Any] = {}
+        self._target_fn: Dict[str, Any] = {}
+        self._grader: Optional[Task3Grader] = None
+        self._step_count: int = 0
+        self._cum_reward: float = 0.0
+        self._done: bool = False
+        self._submitted: bool = False
+        self._query_hist: List[str] = []
+        self._seen: Set[str] = set()
+
+    # ── OpenEnv interface ─────────────────────────────────────────────────
+
+    def reset(self, seed: Optional[int] = None) -> ResetResult:
+        if seed is not None:
+            self._rng.seed(seed)
+
+        self._contract, self._target_fn = sample_task3_episode(
+            self._contracts, self._rng
+        )
+        t3 = self._target_fn["task3"]
+        self._grader = Task3Grader(
+            target_function=self._target_fn["name"],
+            partial_credit_functions=t3.get("partial_credit_functions", []),
+            property_english=t3.get("property_english", ""),
+        )
+        self._step_count = 0
+        self._cum_reward = 0.0
+        self._done = False
+        self._submitted = False
+        self._query_hist = []
+        self._seen = set()
+
+        obs = self._build_obs(
+            last_action=None,
+            last_result=(
+                f"New episode started.\n"
+                f"Contract : {self._contract['contract_name']}\n\n"
+                f"Property : {t3['property_english']}\n\n"
+                f"Find the function in this contract that violates the property above.\n"
+                f"Use list_functions then get_function_code to investigate.\n"
+                f"Submit with submit_function, params={{\"function_name\": \"...\"}}.\n"
+                f"ONE submission allowed."
+            ),
+        )
+        return ResetResult(observation=obs, info={"task_id": TASK_ID})
+
+    def step(self, action: Action) -> StepResult:
+        if self._done:
+            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
+
+        self._step_count += 1
+        result_text, reward = self._dispatch(action)
+        self._cum_reward += reward.value
+        self._query_hist.append(f"[{action.action_type}] → {result_text[:100]}")
+
+        obs = self._build_obs(
+            last_action=action.action_type,
+            last_result=result_text,
+        )
+        return StepResult(
+            observation=obs,
+            reward=reward,
+            done=self._done,
+            info={"step": self._step_count, "cumulative_reward": self._cum_reward},
+        )
+
+    def state(self) -> StateResult:
+        return StateResult(
+            task_id=TASK_ID,
+            contract_name=self._contract.get("contract_name", ""),
+            target_function=self._target_fn.get("name"),
+            step_count=self._step_count,
+            cumulative_reward=self._cum_reward,
+            done=self._done,
+            query_history=list(self._query_hist),
+        )
+
+    # ── Internal helpers ──────────────────────────────────────────────────
+
+    def _build_obs(self, last_action: Optional[str], last_result: str) -> Observation:
+        t3 = self._target_fn.get("task3", {})
+        return Observation(
+            task_id=TASK_ID,
+            contract_name=self._contract.get("contract_name", ""),
+            contract_description=self._contract.get("metadata", {}).get("description", ""),
+            available_actions=[a.value for a in AVAILABLE_ACTIONS],
+            last_action=last_action,
+            last_action_result=last_result,
+            step_count=self._step_count,
+            cumulative_reward=self._cum_reward,
+            done=self._done,
+            extra={
+                "property_english": t3.get("property_english", ""),
+                "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
+                "hint": (
+                    "Read the property, then inspect function code to find which one violates it. "
+                    "Submit with: submit_function, params={'function_name': '<name>'}. "
+                    "ONE submission per episode."
+                ),
+            },
+        )
+
+    def _qkey(self, at: str, params: Dict[str, Any]) -> str:
+        return f"{at}:{sorted(params.items())}"
+
+    def _is_repeated(self, key: str) -> bool:
+        if key in self._seen:
+            return True
+        self._seen.add(key)
+        return False
+
+    def _dispatch(self, action: Action) -> tuple[str, Reward]:
+        at = action.action_type
+        params = action.params
+        qkey = self._qkey(at, params)
+
+        # ── list_functions ────────────────────────────────────────────────
+        if at == ActionType.LIST_FUNCTIONS:
+            if self._is_repeated(qkey):
+                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+            names = list_function_names(self._contract)
+            return (
+                f"Functions in {self._contract['contract_name']}: {', '.join(names)}",
+                Reward(value=-0.05, reason="list_functions cost"),
+            )
+
+        # ── get_function_metadata ─────────────────────────────────────────
+        if at == ActionType.GET_FUNCTION_METADATA:
+            fn_name = params.get("function_name", "")
+            if self._is_repeated(qkey):
+                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+            fn = get_function_by_name(self._contract, fn_name)
+            if fn is None:
+                return (
+                    f"Function '{fn_name}' not found. "
+                    f"Available: {list_function_names(self._contract)}",
+                    Reward(value=-0.05, reason="Unknown function"),
+                )
+            params_list = fn.get("parameters", [])
+            modifiers = fn.get("modifiers", [])
+            lines = [
+                f"Function   : {fn.get('signature', fn_name)}",
+                f"Visibility : {fn.get('visibility', 'unknown')}",
+                f"Modifiers  : {', '.join(modifiers) if modifiers else 'none'}",
+            ]
+            if params_list:
+                lines.append("Parameters :")
+                for p in params_list:
+                    lines.append(f"  {p['type']} {p['name']} — {p.get('description', '')}")
+            else:
+                lines.append("Parameters : none")
+            lines.append(f"Returns    : {fn.get('returns', '') or 'void'}")
+            lines.append(f"Summary    : {fn.get('comment', '')}")
+            return "\n".join(lines), Reward(value=-0.05, reason="get_function_metadata cost")
+
+        # ── get_function_code ─────────────────────────────────────────────
+        if at == ActionType.GET_FUNCTION_CODE:
+            fn_name = params.get("function_name", "")
+            if self._is_repeated(qkey):
+                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+            fn = get_function_by_name(self._contract, fn_name)
+            if fn is None:
+                return (
+                    f"Function '{fn_name}' not found. "
+                    f"Available: {list_function_names(self._contract)}",
+                    Reward(value=-0.10, reason="Unknown function — extra penalty"),
+                )
+            code = fn.get("code", "// no code available")
+            return (
+                f"// {fn_name}\n{code}",
+                Reward(value=-0.10, reason="get_function_code cost"),
+            )
+
+        # ── get_state_variables ───────────────────────────────────────────
+        if at == ActionType.GET_STATE_VARIABLE:
+            var_name = params.get("variable_name", "")
+            if self._is_repeated(qkey):
+                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+            if not var_name:
+                names = list_state_variable_names(self._contract)
+                return (
+                    f"State variables: {', '.join(names)}",
+                    Reward(value=-0.05, reason="Listed state variables"),
+                )
+            sv = get_state_variable_by_name(self._contract, var_name)
+            if sv is None:
+                return (
+                    f"Variable '{var_name}' not found.",
+                    Reward(value=-0.05, reason="Unknown state variable"),
+                )
+            return (
+                f"{sv['type']} {sv['visibility']} {sv['name']}: {sv.get('description', '')}",
+                Reward(value=-0.05, reason="get_state_variable cost"),
+            )
+
+        # ── get_call_graph ────────────────────────────────────────────────
+        if at == ActionType.GET_CALL_GRAPH:
+            if self._is_repeated(qkey):
+                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+            cg = self._contract.get("call_graph", {})
+            cg_str = "; ".join(
+                f"{fn} → [{', '.join(callees)}]" for fn, callees in cg.items()
+            )
+            return (
+                f"Call graph: {cg_str}",
+                Reward(value=-0.08, reason="get_call_graph cost"),
+            )
+
+        # ── get_formalized_property ───────────────────────────────────────
+        if at == ActionType.GET_FORMALIZED_PROPERTY:
+            if self._is_repeated(qkey):
+                return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
+            formal = self._target_fn.get("task3", {}).get("property_formal", "")
+            if not formal:
+                formal = "No formal specification available for this property."
+            return (
+                f"Formal property:\n{formal}",
+                Reward(value=-0.03, reason="get_formalized_property cost"),
+            )
+
+        # ── submit_function ───────────────────────────────────────────────
+        if at == ActionType.SUBMIT_FUNCTION:
+            if self._submitted:
+                return (
+                    "❌ You have already submitted for this episode. "
+                    "Only ONE submission is allowed.",
+                    Reward(value=-1.0, reason="Second submit_function attempt", partial=False),
+                )
+            fn_name = params.get("function_name", "").strip()
+            if not fn_name:
+                return (
+                    "submit_function requires 'function_name' in params.",
+                    Reward(value=-0.5, reason="Malformed submission"),
+                )
+
+            self._submitted = True
+            self._done = True
+            score, reward_val = self._grader.grade_and_reward(fn_name)
+            correct = self._grader.get_canonical_answer()
+
+            if score >= 0.9:
+                msg = (
+                    f"✅ CORRECT! '{fn_name}' is the function that violates the property. "
+                    f"Score: 1.0 → Reward: +{reward_val:.1f}"
+                )
+            elif score >= 0.2:
+                msg = (
+                    f"🟡 PARTIAL. '{fn_name}' is a subfunction of the target — "
+                    f"closely related but not the primary rule-breaker. "
+                    f"Score: 0.3 → Reward: +{reward_val:.1f}. "
+                    f"Correct answer: '{correct['target_function']}'."
+                )
+            else:
+                msg = (
+                    f"❌ INCORRECT. '{fn_name}' does not violate the property. "
+                    f"Score: 0.0 → Reward: {reward_val:.1f}. "
+                    f"Correct answer: '{correct['target_function']}'."
+                )
+
+            return msg, Reward(
+                value=reward_val,
+                reason=f"submit_function score={score:.1f}",
+                partial=False,
+            )
+
+        # ── unknown action ────────────────────────────────────────────────
+        return (
+            f"Unknown action '{at}'. Valid: {[a.value for a in AVAILABLE_ACTIONS]}",
+            Reward(value=-0.10, reason="Unknown action"),
+        )
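The `_qkey`/`_is_repeated` pair in the environment implements the flat -0.40 repeated-query penalty: the same action with the same params is free-form the first time and penalized thereafter. The mechanism in isolation, with a hypothetical base cost inlined:

```python
# Set of query keys already seen this episode (mirrors Task3Environment._seen).
seen = set()

def query_cost(action: str, params: dict, base_cost: float) -> float:
    """Return the shaping cost for a query; -0.40 if it repeats an earlier one."""
    key = f"{action}:{sorted(params.items())}"  # same key scheme as _qkey
    if key in seen:
        return -0.40
    seen.add(key)
    return base_cost

c1 = query_cost("get_function_code", {"function_name": "withdraw"}, -0.10)  # first call
c2 = query_cost("get_function_code", {"function_name": "withdraw"}, -0.10)  # repeat
c3 = query_cost("get_function_code", {"function_name": "deposit"}, -0.10)   # new params
print(c1, c2, c3)  # → -0.1 -0.4 -0.1
```

Sorting the params before keying makes the dedup insensitive to argument order.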
tasks/task3/grader.py ADDED
@@ -0,0 +1,80 @@
+"""
+grader.py (Task 3 – Rule Checker)
+---------------------------------
+Deterministic grader for function-identification submissions.
+
+Score table
+───────────
+1.0 → submitted function is the exact target (case-insensitive)
+0.3 → submitted function is a direct internal subfunction of the target
+      (a contract-internal function called by the target in the call graph)
+0.0 → anything else
+
+Reward table (ONE submission per episode)
+score 1.0 → +5.0
+score 0.3 → +1.5
+score 0.0 → -1.5
+"""
+
+from __future__ import annotations
+from typing import Dict, List, Optional
+
+
+class Task3Grader:
+    """
+    Grades a Task 3 submit_function submission.
+
+    Parameters
+    ----------
+    target_function          : exact name of the rule-breaking function
+    partial_credit_functions : list of internal functions that get partial credit
+                               (direct callees of the target that are contract functions)
+    property_english         : the English property text (for feedback messages)
+    """
+
+    SCORE_CORRECT = 1.0
+    SCORE_PARTIAL = 0.3
+    SCORE_WRONG = 0.0
+
+    REWARD_CORRECT = 5.0
+    REWARD_PARTIAL = 1.5
+    REWARD_WRONG = -1.5
+
+    def __init__(
+        self,
+        target_function: str,
+        partial_credit_functions: List[str],
+        property_english: str = "",
+    ) -> None:
+        self.target_function = target_function.lower()
+        self.partial_credit_functions = [f.lower() for f in partial_credit_functions]
+        self.property_english = property_english
+
+    def grade(self, submitted_function: str) -> float:
+        """Returns deterministic score in {0.0, 0.3, 1.0}."""
+        norm = submitted_function.strip().lower()
+        if norm == self.target_function:
+            return self.SCORE_CORRECT
+        if norm in self.partial_credit_functions:
+            return self.SCORE_PARTIAL
+        return self.SCORE_WRONG
+
+    def reward_for_score(self, score: float) -> float:
+        """Maps score → terminal reward."""
+        if score >= 0.9:
+            return self.REWARD_CORRECT
+        if score >= 0.2:
+            return self.REWARD_PARTIAL
+        return self.REWARD_WRONG
+
+    def grade_and_reward(self, submitted_function: str):
+        """Convenience: returns (score, reward)."""
+        score = self.grade(submitted_function)
+        return score, self.reward_for_score(score)
+
+    def get_canonical_answer(self) -> Dict[str, object]:
+        """For debugging / logging only — do not expose to the agent."""
+        return {
+            "target_function": self.target_function,
+            "partial_credit_functions": self.partial_credit_functions,
+        }
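A usage sketch of the grading rules, with the score and reward tables inlined so it runs standalone; `withdraw` and `_transfer` are hypothetical function names, not contracts from the dataset:

```python
def grade(submitted: str, target: str, partial: list[str]) -> float:
    """Case-insensitive exact / subfunction / wrong scoring, as in Task3Grader."""
    norm = submitted.strip().lower()
    if norm == target.lower():
        return 1.0
    if norm in [f.lower() for f in partial]:
        return 0.3
    return 0.0

def reward_for_score(score: float) -> float:
    """Score tier → terminal reward (+5.0 / +1.5 / -1.5)."""
    return 5.0 if score >= 0.9 else (1.5 if score >= 0.2 else -1.5)

print(grade("Withdraw", "withdraw", ["_transfer"]))                  # → 1.0 (case-insensitive)
print(reward_for_score(grade("_transfer", "withdraw", ["_transfer"])))  # → 1.5 (subfunction)
```

Because scoring is pure string comparison, the grader is fully deterministic: the same submission always yields the same score and reward.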
validate.py CHANGED
@@ -1,19 +1,18 @@
  """
  validate.py
  -----------
- Pre-submission validation. Checks all OpenEnv spec requirements.
-
- Usage: python validate.py
- Exit 0 = all checks pass. Exit 1 = one or more failures.
  """
 
- import json, sys, traceback
  from typing import Callable, List, Tuple
 
  PASS = "✅"; FAIL = "❌"
  results: List[Tuple[str, bool, str]] = []
 
- def check(name: str, fn: Callable[[], None]) -> None:
      try:
          fn(); results.append((name, True, ""))
          print(f" {PASS} {name}")
@@ -25,45 +24,41 @@ def check(name: str, fn: Callable[[], None]) -> None:
 
  def check_imports():
      from env.schemas import Observation, Action, Reward, StepResult, ResetResult, StateResult, ActionType
-     from tasks.task1.environment import Task1Environment
-     from tasks.task1.grader import Task1Grader
-     from tasks.task2.environment import Task2Environment
-     from tasks.task2.grader import Task2Grader
      from data.data_loader import load_contracts
 
  def check_openenv_yaml():
      import yaml
      with open("openenv.yaml") as f: spec = yaml.safe_load(f)
-     assert "name" in spec
-     assert len(spec.get("tasks", [])) >= 3
-     assert "observation_space" in spec
-     assert "action_space" in spec
-     assert "reward" in spec
 
  def check_pydantic_models():
-     from env.schemas import Observation, Action, ActionType, Reward, StepResult, ResetResult, StateResult
-     obs = Observation(task_id="t1", contract_name="C", contract_description="D", available_actions=["submit"])
-     assert obs.task_id == "t1"
-     action = Action(action_type=ActionType.LIST_FUNCTIONS); assert action.action_type == ActionType.LIST_FUNCTIONS
-     action2 = Action(action_type=ActionType.SUBMIT_PROPERTY); assert action2.action_type == ActionType.SUBMIT_PROPERTY
-     reward = Reward(value=1.0, reason="test"); assert reward.value == 1.0
-     step = StepResult(observation=obs, reward=reward, done=False); assert not step.done
-     reset = ResetResult(observation=obs); assert reset.observation.task_id == "t1"
 
  def check_data_loading():
-     from data.data_loader import load_contracts, get_all_vulnerable_entries, get_all_property_entries
-     contracts = load_contracts()
-     assert len(contracts) >= 1
-     vuln_entries = get_all_vulnerable_entries(contracts)
-     assert len(vuln_entries) >= 3, f"Need >=3 vulnerable fns, got {len(vuln_entries)}"
-     prop_entries = get_all_property_entries(contracts)
-     assert len(prop_entries) >= 3, f"Need >=3 property fns, got {len(prop_entries)}"
-     for _, fn in prop_entries:
-         p = fn["property"]
-         assert "natural_language" in p
-         assert "key_phrases" in p
-         assert "bonus_phrases" in p
-         assert len(p["key_phrases"]) >= 2
 
  def check_t1_env():
      from tasks.task1.environment import Task1Environment
@@ -71,9 +66,8 @@ def check_t1_env():
      env = Task1Environment()
      r = env.reset(seed=42); assert r.observation.task_id == "task1_vuln_detection"
      s = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
-     assert isinstance(s.reward.value, float)
-     assert s.observation.step_count == 1
-     st = env.state(); assert st.target_function is not None
 
  def check_t2_env():
      from tasks.task2.environment import Task2Environment
@@ -82,178 +76,206 @@ def check_t2_env():
      r = env.reset(seed=42)
      assert r.observation.task_id == "task2_property_discovery"
      assert "target_function" in r.observation.extra
-     # test each action type
      for at in [ActionType.GET_FUNCTION_CODE, ActionType.GET_FUNCTION_NATSPEC,
-                ActionType.GET_FILE_NATSPEC, ActionType.GET_IO, ActionType.GET_RELATED_FUNCTIONS]:
-         s = env.step(Action(action_type=at)); assert s.reward.value < 0
-     s = env.step(Action(action_type=ActionType.GET_SIMILAR_RULE))
-     assert s.reward.value == -0.20
 
- def check_t2_env_submit():
-     from tasks.task2.environment import Task2Environment
-     from data.data_loader import load_contracts, get_function_by_name
      from env.schemas import Action, ActionType
-     env = Task2Environment()
      r = env.reset(seed=42)
-     fn_name = r.observation.extra["target_function"]
-     contract = r.observation.contract_name
-     contracts = load_contracts()
-     gt_text = ""
-     for c in contracts:
-         if c["contract_name"] == contract:
-             fn = get_function_by_name(c, fn_name)
-             if fn and fn.get("property"):
-                 gt_text = fn["property"]["natural_language"]
-     result = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY, params={"property": gt_text}))
-     assert result.done
-     assert result.reward.value > 0, f"GT text should score >0, got {result.reward.value}"
-
- def check_t2_one_submit_only():
-     from tasks.task2.environment import Task2Environment
      from env.schemas import Action, ActionType
-     env = Task2Environment()
-     env.reset(seed=5)
-     env.step(Action(action_type=ActionType.SUBMIT_PROPERTY, params={"property": "test"}))
-     # Second submit must either fail (episode done -> RuntimeError) or return negative reward
-     try:
-         s2 = env.step(Action(action_type=ActionType.SUBMIT_PROPERTY, params={"property": "test2"}))
-         # If it doesn't raise, the reward must be negative
-         assert s2.reward.value < 0, "Second submit should penalise"
-     except RuntimeError:
-         pass  # expected
 
- def check_t1_grader():
-     from tasks.task1.grader import Task1Grader
-     cases = [
-         ("withdraw", "Reentrancy vulnerability", "withdraw", "reentrancy", 1.0),
-         ("withdraw", "Reentrancy vulnerability", "withdraw", "something else", 0.5),
-         ("withdraw", "Reentrancy vulnerability", "deposit", "reentrancy", 0.0),
-     ]
-     for tf, issue, sf, sv, expected in cases:
-         g = Task1Grader(tf, issue)
-         score = g.grade_submission(sf, sv)
-         assert 0.0 <= score <= 1.0
-         assert abs(score - expected) < 0.01, f"Expected {expected}, got {score}"
 
- def check_t2_grader():
-     from tasks.task2.grader import Task2Grader
-     from data.data_loader import load_contracts, get_all_property_entries
-     contracts = load_contracts()
-     entries = get_all_property_entries(contracts)
-     for contract, fn in entries:
-         g = Task2Grader(fn["name"], fn["property"])
-         # Ground truth must score ≥ 0.65
-         gt_score = g.grade(fn["property"]["natural_language"])
-         assert gt_score >= 0.65, f"{fn['name']}: gt_score={gt_score} < 0.65"
-         # Empty must be 0.0
-         assert g.grade("") == 0.0
-         # Deterministic
-         assert g.grade("test text") == g.grade("test text")
-         # Score in [0,1]
-         assert 0.0 <= gt_score <= 1.0
-         # Reward maps correctly
-         assert abs(g.reward_for_score(gt_score) - gt_score * 5.0) < 0.01
 
- def check_reward_shaping():
-     from tasks.task2.environment import Task2Environment
      from env.schemas import Action, ActionType
-     env = Task2Environment()
-     env.reset(seed=1)
-     rewards = {env.step(Action(action_type=at)).reward.value
-                for at in [ActionType.GET_FUNCTION_CODE, ActionType.GET_FILE_NATSPEC, ActionType.GET_IO]}
-     assert len(rewards) >= 2, f"Need multiple reward values, got {rewards}"
 
- def check_t1_episode_boundary():
-     from tasks.task1.environment import Task1Environment
      from env.schemas import Action, ActionType
-     env = Task1Environment()
-     env.reset(seed=2)
-     env.step(Action(action_type=ActionType.SUBMIT,
-                     params={"function_name": "withdraw", "vulnerability_type": "test"}))
      try:
          env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
          raise AssertionError("Should raise RuntimeError after done")
      except RuntimeError:
          pass
 
- def check_repeated_query_penalty():
-     from tasks.task1.environment import Task1Environment
      from env.schemas import Action, ActionType
-     env = Task1Environment(); env.reset(seed=3)
      env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
-     r = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
-     assert r.reward.value == -0.40
 
- def check_t2_repeated_penalty():
-     from tasks.task2.environment import Task2Environment
      from env.schemas import Action, ActionType
-     env = Task2Environment(); env.reset(seed=3)
-     env.step(Action(action_type=ActionType.GET_FUNCTION_CODE))
-     r = env.step(Action(action_type=ActionType.GET_FUNCTION_CODE))
-     assert r.reward.value == -0.40
 
- def check_task_placeholders():
-     from tasks.task3 import __all__ as t3
 
  def check_dockerfile():
      import os
      assert os.path.exists("Dockerfile")
-     with open("Dockerfile") as f: c = f.read()
-     assert "7860" in c
-     assert "uvicorn" in c or "CMD" in c
 
  def check_inference_script():
      import os
      assert os.path.exists("inference.py")
-     with open("inference.py") as f: c = f.read()
-     assert "HF_TOKEN" in c
-     assert "API_BASE_URL" in c
-     assert "MODEL_NAME" in c
-     assert "task2" in c.lower() or "Task2" in c or "TASK 2" in c
 
  def check_baseline_json():
      import os
      if not os.path.exists("baseline_scores.json"): return
-     with open("baseline_scores.json") as f: data = json.load(f)
-     assert "tasks" in data
-     for t in data["tasks"]:
          assert 0.0 <= t["avg_grader_score"] <= 1.0
 
- def check_similar_rule_lookup():
-     from data.data_loader import load_contracts, get_similar_rule
-     contracts = load_contracts()
-     sr = get_similar_rule(contracts, "SimpleVault", "withdraw")
-     assert sr is not None, "similar_rule should exist for withdraw"
-     assert "property_hint" in sr
-     assert "contract_name" in sr
-
  # ── Runner ──────────────────────────────────────────────────────────────────
 
  ALL_CHECKS = [
-     ("Python imports (T1 + T2)",            check_imports),
-     ("openenv.yaml format",                 check_openenv_yaml),
-     ("Pydantic models (incl T2 actions)",   check_pydantic_models),
-     ("Dataset: vuln + property entries",    check_data_loading),
-     ("Task 1: reset / step / state",        check_t1_env),
-     ("Task 2: reset + all 6 browse actions",check_t2_env),
-     ("Task 2: submit_property scores > 0",  check_t2_env_submit),
-     ("Task 2: one submit only",             check_t2_one_submit_only),
-     ("Task 1 grader: 0/0.5/1.0 rubric",     check_t1_grader),
-     ("Task 2 grader: all 11 properties",    check_t2_grader),
-     ("Reward shaping (multi-value)",        check_reward_shaping),
-     ("T1 episode boundary",                 check_t1_episode_boundary),
-     ("T1 repeated query penalty (-0.40)",   check_repeated_query_penalty),
-     ("T2 repeated query penalty (-0.40)",   check_t2_repeated_penalty),
-     ("Task 3 placeholder exists",           check_task_placeholders),
      ("Dockerfile + port 7860",              check_dockerfile),
-     ("inference.py: creds + Task 2 code",   check_inference_script),
      ("baseline_scores.json schema",         check_baseline_json),
-     ("similar_rule data lookup",            check_similar_rule_lookup),
  ]
 
  def main():
      print("=" * 64)
-     print("OpenEnv Pre-Submission Validation (Task 1 + Task 2)")
      print("=" * 64)
      print()
      for name, fn in ALL_CHECKS:
@@ -270,7 +292,7 @@ def main():
          print("\nFailed checks:")
          for n, m in failed:
              print(f" {FAIL} {n}: {m}")
-         print("\n❌ VALIDATION FAILED - fix the issues above before submitting.")
          sys.exit(1)
      else:
          print("\n✅ ALL CHECKS PASSED - ready to submit!")
 
  """
  validate.py
  -----------
+ Pre-submission validation - 23 checks across all three tasks.
+ Usage: python validate.py
+ Exit 0 = all pass. Exit 1 = failures.
  """
 
+ import json, sys
  from typing import Callable, List, Tuple
 
  PASS = "✅"; FAIL = "❌"
  results: List[Tuple[str, bool, str]] = []
 
+ def check(name: str, fn: Callable) -> None:
      try:
          fn(); results.append((name, True, ""))
          print(f" {PASS} {name}")
 
  def check_imports():
      from env.schemas import Observation, Action, Reward, StepResult, ResetResult, StateResult, ActionType
+     from tasks.task1.environment import Task1Environment; from tasks.task1.grader import Task1Grader
+     from tasks.task2.environment import Task2Environment; from tasks.task2.grader import Task2Grader
+     from tasks.task3.environment import Task3Environment; from tasks.task3.grader import Task3Grader
      from data.data_loader import load_contracts
 
  def check_openenv_yaml():
      import yaml
      with open("openenv.yaml") as f: spec = yaml.safe_load(f)
+     assert "name" in spec and len(spec.get("tasks", [])) >= 3
+     assert "observation_space" in spec and "action_space" in spec and "reward" in spec
+     tasks = spec["tasks"]
+     active = [t for t in tasks if t.get("status") == "active"]
+     assert len(active) >= 2, f"Expected >=2 active tasks, got {len(active)}"
 
  def check_pydantic_models():
+     from env.schemas import Observation, Action, ActionType, Reward, StepResult, ResetResult
+     obs = Observation(task_id="t", contract_name="C", contract_description="D", available_actions=[])
+     for at in [ActionType.LIST_FUNCTIONS, ActionType.SUBMIT_PROPERTY,
+                ActionType.GET_FORMALIZED_PROPERTY, ActionType.SUBMIT_FUNCTION]:
+         Action(action_type=at)
+     Reward(value=-1.5, reason="test")
+     StepResult(observation=obs, reward=Reward(value=0, reason=""), done=False)
 
  def check_data_loading():
+     from data.data_loader import (load_contracts, get_all_vulnerable_entries,
+                                   get_all_property_entries, get_all_task3_entries)
+     c = load_contracts()
+     assert len(get_all_vulnerable_entries(c)) >= 3
+     assert len(get_all_property_entries(c)) >= 3
+     entries = get_all_task3_entries(c)
+     assert len(entries) >= 3, f"Need >=3 task3 entries, got {len(entries)}"
+     for _, fn in entries:
+         t3 = fn.get("task3", {})
+         assert t3.get("property_english"), f"{fn['name']} missing property_english"
+         assert t3.get("property_formal"), f"{fn['name']} missing property_formal"
 
  def check_t1_env():
      from tasks.task1.environment import Task1Environment
      env = Task1Environment()
      r = env.reset(seed=42); assert r.observation.task_id == "task1_vuln_detection"
      s = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
+     assert s.reward.value == -0.05 and s.observation.step_count == 1
+     assert env.state().target_function is not None
 
  def check_t2_env():
      from tasks.task2.environment import Task2Environment
      r = env.reset(seed=42)
      assert r.observation.task_id == "task2_property_discovery"
      assert "target_function" in r.observation.extra
      for at in [ActionType.GET_FUNCTION_CODE, ActionType.GET_FUNCTION_NATSPEC,
+                ActionType.GET_FILE_NATSPEC, ActionType.GET_IO,
+                ActionType.GET_RELATED_FUNCTIONS, ActionType.GET_SIMILAR_RULE]:
+         env.step(Action(action_type=at))
 
+ def check_t3_env():
+     from tasks.task3.environment import Task3Environment
      from env.schemas import Action, ActionType
+     env = Task3Environment()
      r = env.reset(seed=42)
+     assert r.observation.task_id == "task3_rule_checker"
+     assert "property_english" in r.observation.extra
+     prop = r.observation.extra["property_english"]
+     assert len(prop) > 10, "property_english too short"
+     for at in [ActionType.LIST_FUNCTIONS, ActionType.GET_FORMALIZED_PROPERTY,
+                ActionType.GET_CALL_GRAPH, ActionType.GET_STATE_VARIABLE]:
+         s = env.step(Action(action_type=at))
+         assert s.reward.value < 0, f"{at.value} should have negative shaping reward"
+
+ def check_t3_action_costs():
+     from tasks.task3.environment import Task3Environment
      from env.schemas import Action, ActionType
+     env = Task3Environment(); env.reset(seed=42)
+     costs = {
+         ActionType.GET_FORMALIZED_PROPERTY: -0.03,
+         ActionType.LIST_FUNCTIONS: -0.05,
+         ActionType.GET_CALL_GRAPH: -0.08,
+     }
+     for at, expected in costs.items():
+         e2 = Task3Environment(); e2.reset(seed=42)
+         s = e2.step(Action(action_type=at))
+         assert abs(s.reward.value - expected) < 0.001, \
+             f"{at.value}: expected {expected}, got {s.reward.value}"
 
+ def check_t3_function_metadata():
+     from tasks.task3.environment import Task3Environment
+     from env.schemas import Action, ActionType
+     env = Task3Environment(); env.reset(seed=43)
+     s = env.step(Action(action_type=ActionType.GET_FUNCTION_METADATA,
+                         params={"function_name": "withdraw"}))
+     assert "Visibility" in s.observation.last_action_result
+     assert s.reward.value == -0.05
 
+ def check_t3_submit_correct():
+     from tasks.task3.environment import Task3Environment
+     from env.schemas import Action, ActionType
+     env = Task3Environment(); env.reset(seed=42)
+     target = env.state().target_function
+     s = env.step(Action(action_type=ActionType.SUBMIT_FUNCTION,
+                         params={"function_name": target}))
+     assert s.done and s.reward.value == 5.0, \
+         f"Expected reward=5.0, got {s.reward.value}"
 
+ def check_t3_submit_subfunction():
+     from tasks.task3.environment import Task3Environment
      from env.schemas import Action, ActionType
+     # seed 45 → bid with subfunction getPrice
+     env = Task3Environment(); env.reset(seed=45)
+     assert env.state().target_function == "bid"
+     s = env.step(Action(action_type=ActionType.SUBMIT_FUNCTION,
+                         params={"function_name": "getPrice"}))
+     assert s.done and s.reward.value == 1.5, \
+         f"Expected partial reward=1.5, got {s.reward.value}"
 
+ def check_t3_submit_wrong():
+     from tasks.task3.environment import Task3Environment
      from env.schemas import Action, ActionType
+     env = Task3Environment(); env.reset(seed=42)
+     s = env.step(Action(action_type=ActionType.SUBMIT_FUNCTION,
+                         params={"function_name": "constructor"}))
+     assert s.done and s.reward.value == -1.5
+
+ def check_t3_one_submit_only():
+     from tasks.task3.environment import Task3Environment
+     from env.schemas import Action, ActionType
+     env = Task3Environment(); env.reset(seed=42)
+     env.step(Action(action_type=ActionType.SUBMIT_FUNCTION,
+                     params={"function_name": "deposit"}))
      try:
          env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
          raise AssertionError("Should raise RuntimeError after done")
      except RuntimeError:
          pass
 
+ def check_t3_repeated_penalty():
+     from tasks.task3.environment import Task3Environment
      from env.schemas import Action, ActionType
+     env = Task3Environment(); env.reset(seed=42)
      env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
+     s = env.step(Action(action_type=ActionType.LIST_FUNCTIONS))
+     assert s.reward.value == -0.40
 
+ def check_t1_grader():
+     from tasks.task1.grader import Task1Grader
+     g = Task1Grader("withdraw", "Reentrancy vulnerability")
+     assert g.grade_submission("withdraw", "reentrancy") == 1.0
+     assert g.grade_submission("withdraw", "vague") == 0.5
+     assert g.grade_submission("deposit", "reentrancy") == 0.0
+
+ def check_t2_grader():
+     from tasks.task2.grader import Task2Grader
+     from data.data_loader import load_contracts, get_all_property_entries
+     for c, fn in get_all_property_entries(load_contracts()):
+         g = Task2Grader(fn["name"], fn["property"])
+         assert g.grade(fn["property"]["natural_language"]) >= 0.65
+         assert g.grade("") == 0.0
+         s = g.grade("test"); assert s == g.grade("test")  # deterministic
+
+ def check_t3_grader():
+     from tasks.task3.grader import Task3Grader
+     g = Task3Grader("withdraw", ["deposit"], "some rule")
+     assert g.grade("withdraw") == 1.0
+     assert g.grade("WITHDRAW") == 1.0  # case-insensitive
+     assert g.grade("deposit") == 0.3
+     assert g.grade("constructor") == 0.0
+     s, r = g.grade_and_reward("withdraw"); assert s == 1.0 and r == 5.0
+     s, r = g.grade_and_reward("deposit"); assert s == 0.3 and r == 1.5
+     s, r = g.grade_and_reward("other"); assert s == 0.0 and r == -1.5
+
+ def check_reward_shaping():
+     from tasks.task3.environment import Task3Environment
      from env.schemas import Action, ActionType
+     env = Task3Environment(); env.reset(seed=1)
+     rewards = {env.step(Action(action_type=at)).reward.value
+                for at in [ActionType.LIST_FUNCTIONS,
+                           ActionType.GET_FORMALIZED_PROPERTY,
+                           ActionType.GET_CALL_GRAPH]}
+     assert len(rewards) >= 2
 
+ def check_app_imports():
+     from app import app
+     from fastapi.testclient import TestClient
+     client = TestClient(app)
+     r = client.get("/health"); assert r.status_code == 200
+     tasks = client.get("/tasks").json()["tasks"]
+     active = [t for t in tasks if t["status"] == "active"]
+     assert len(active) == 3, f"Expected 3 active tasks, got {len(active)}: {active}"
+
+ def check_t3_http_reset():
+     from app import app
+     from fastapi.testclient import TestClient
+     client = TestClient(app)
+     r = client.post("/reset", json={"task_id": "task3_rule_checker", "seed": 42})
+     assert r.status_code == 200
+     obs = r.json()["observation"]
+     assert obs["task_id"] == "task3_rule_checker"
+     assert "property_english" in obs["extra"]
 
  def check_dockerfile():
      import os
      assert os.path.exists("Dockerfile")
+     c = open("Dockerfile").read()
+     assert "7860" in c and ("uvicorn" in c or "CMD" in c)
 
  def check_inference_script():
      import os
      assert os.path.exists("inference.py")
+     c = open("inference.py").read()
+     assert "HF_TOKEN" in c and "API_BASE_URL" in c and "MODEL_NAME" in c
+     assert "Task3Environment" in c or "run_task3" in c
+     assert "submit_function" in c
 
  def check_baseline_json():
      import os
      if not os.path.exists("baseline_scores.json"): return
+     data = json.load(open("baseline_scores.json"))
+     for t in data.get("tasks", []):
          assert 0.0 <= t["avg_grader_score"] <= 1.0
 
  # ── Runner ──────────────────────────────────────────────────────────────────
 
  ALL_CHECKS = [
+     ("Python imports (T1+T2+T3)",           check_imports),
+     ("openenv.yaml: 3 tasks, ≥2 active",    check_openenv_yaml),
+     ("Pydantic models (all ActionTypes)",   check_pydantic_models),
+     ("Dataset: vuln+property+task3 entries",check_data_loading),
+     ("T1 env: reset/step/state",            check_t1_env),
+     ("T2 env: reset + 6 browse actions",    check_t2_env),
+     ("T3 env: reset + browse actions",      check_t3_env),
+     ("T3 action costs (formalized -0.03)",  check_t3_action_costs),
+     ("T3 get_function_metadata",            check_t3_function_metadata),
+     ("T3 submit correct → +5.0",            check_t3_submit_correct),
+     ("T3 submit subfunction → +1.5",        check_t3_submit_subfunction),
+     ("T3 submit wrong → -1.5",              check_t3_submit_wrong),
+     ("T3 one submit per episode",           check_t3_one_submit_only),
+     ("T3 repeated query → -0.40",           check_t3_repeated_penalty),
+     ("T1 grader: 0/0.5/1.0 rubric",         check_t1_grader),
+     ("T2 grader: all 11 properties",        check_t2_grader),
+     ("T3 grader: 1.0/0.3/0.0 + case-ins.",  check_t3_grader),
+     ("Reward shaping non-binary (T3)",      check_reward_shaping),
+     ("FastAPI: 3 active tasks",             check_app_imports),
+     ("FastAPI: T3 reset endpoint",          check_t3_http_reset),
      ("Dockerfile + port 7860",              check_dockerfile),
+     ("inference.py: T3 code present",       check_inference_script),
      ("baseline_scores.json schema",         check_baseline_json),
  ]
 
  def main():
      print("=" * 64)
+     print("OpenEnv Pre-Submission Validation (Task 1 + 2 + 3)")
      print("=" * 64)
      print()
      for name, fn in ALL_CHECKS:
          print("\nFailed checks:")
          for n, m in failed:
              print(f" {FAIL} {n}: {m}")
+         print("\n❌ VALIDATION FAILED")
          sys.exit(1)
      else:
          print("\n✅ ALL CHECKS PASSED - ready to submit!")
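The runner pattern validate.py relies on (each check is a plain function that raises on failure; the harness records `(name, passed, message)` tuples and maps any failure to exit code 1) can be sketched standalone; the two sample checks below are illustrative stand-ins, not checks from the repo:

```python
# Minimal sketch of the validate.py check-runner pattern.
from typing import Callable, List, Tuple

results: List[Tuple[str, bool, str]] = []

def check(name: str, fn: Callable[[], None]) -> None:
    # Run one check; a raised exception marks it failed with its message.
    try:
        fn()
        results.append((name, True, ""))
    except Exception as e:
        results.append((name, False, str(e)))

def always_ok() -> None:
    pass

def always_fails() -> None:
    raise AssertionError("expected 3 active tasks")

for name, fn in [("ok check", always_ok), ("failing check", always_fails)]:
    check(name, fn)

failed = [(n, m) for n, ok, m in results if not ok]
exit_code = 1 if failed else 0  # any failure -> non-zero exit, as in validate.py
print(f"{len(results) - len(failed)}/{len(results)} passed, exit={exit_code}")
# 1/2 passed, exit=1
```

Keeping the checks exception-based means each one stays a three-to-ten line function and the harness needs no per-check boilerplate.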