ajaxwin committed on
Commit
c6002b4
·
1 Parent(s): 1248d28

fix: Update output instructions in prompts.py to enforce case sensitivity and structure, inference.py made DRY, PROJECT works

Browse files
Files changed (2) hide show
  1. inference.py +164 -233
  2. utils/prompts.py +6 -3
inference.py CHANGED
@@ -7,7 +7,7 @@ Implements agents for all three tasks using the Groq client.
7
  Emits mandatory structured stdout in the OpenEnv format.
8
 
9
  MANDATORY ENV VARS:
10
- GROQ_API_KEY Groq API key (required)
11
  MODEL_NAME Model identifier (default: openai/gpt-oss-20b)
12
 
13
  MANDATORY STDOUT FORMAT (per episode):
@@ -26,15 +26,14 @@ import asyncio
26
  import json
27
  import os
28
  import sys
29
- import time
30
- from typing import Any, Dict, List, Optional
31
 
32
  from openai import AsyncOpenAI
 
33
 
34
  from server import Task1Environment, Task2Environment, Task3Environment
35
  from env.schemas import Action, ActionType
36
  from utils import T1_SYSTEM, T2_SYSTEM, T3_SYSTEM
37
- from dotenv import load_dotenv
38
 
39
  # ─────────────────────────────────────────────────────────────────────────────
40
  # Configuration
@@ -47,24 +46,32 @@ HF_TOKEN = os.getenv("HF_TOKEN", "")
47
 
48
  if not HF_TOKEN:
49
  raise RuntimeError("HF_TOKEN environment variable not set")
50
-
51
  client = AsyncOpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
52
 
 
 
 
 
53
  # Benchmark / environment identifier (constant for this env)
54
  ENV_BENCHMARK = "smart-contract-audit"
55
 
56
  # Episodes per task
57
- NUM_EPISODES = 3
58
  SEED_BASE = 42
59
 
60
- # Max steps per task
61
- MAX_STEPS_T1 = 15
62
- MAX_STEPS_T2 = 15
63
- MAX_STEPS_T3 = 15
64
 
65
  # A grader_score >= this is considered a "success" for the [END] line
66
  SUCCESS_SCORE_THRESHOLD = 0.5
67
 
 
 
 
 
 
 
68
  # ─────────────────────────────────────────────────────────────────────────────
69
  # Unified LLM call function
70
  # ─────────────────────────────────────────────────────────────────────────────
@@ -79,12 +86,19 @@ async def get_llm_response(
79
  Returns the response content as a string.
80
  Raises an exception on failure (to be caught by the caller).
81
  """
82
- completion = await client.chat.completions.create(
83
- model=MODEL_NAME,
84
- messages=messages, # type: ignore
85
- )
86
- return completion.choices[0].message.content.strip() # type: ignore
87
-
 
 
 
 
 
 
 
88
 
89
  # ─────────────────────────────────────────────────────────────────────────────
90
  # Mandatory stdout helpers
@@ -95,12 +109,7 @@ def log_start(task: str, env: str, model: str) -> None:
95
  print(f"[START] task={task} env={env} model={model}", flush=True)
96
 
97
 
98
- def log_step(
99
- step: int,
100
- action: str,
101
- reward: float,
102
- done: bool,
103
- error: Optional[str] = None,
104
  ) -> None:
105
  """Emit a [STEP] line β€” one per env.step() call."""
106
  error_val = error if error else "null"
@@ -111,12 +120,7 @@ def log_step(
111
  )
112
 
113
 
114
- def log_end(
115
- success: bool,
116
- steps: int,
117
- score: float,
118
- rewards: List[float],
119
- ) -> None:
120
  """Emit the [END] line β€” one per episode, always emitted."""
121
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
122
  print(
@@ -125,163 +129,112 @@ def log_end(
125
  flush=True,
126
  )
127
 
128
-
129
  # ─────────────────────────────────────────────────────────────────────────────
130
- # Task 1 β€” Targeted Vulnerability Detection
131
  # ─────────────────────────────────────────────────────────────────────────────
132
 
133
- def _t1_user_msg(obs: Dict[str, Any]) -> str:
134
- return (
135
- f"Last action : {obs['last_action'] or 'None'}\n"
136
- f"Last result : {obs['last_action_result'] or 'Episode just started.'}"
137
- )
138
-
139
-
140
- async def _run_t1_episode(env: Task1Environment, seed: int, ep_num: int) -> Dict[str, Any]:
141
- """Run one Task 1 episode; emit [START]/[STEP]/[END]."""
142
- r = env.reset(seed=seed)
 
 
 
 
 
 
 
143
  obs = r.observation.model_dump()
144
 
145
- log_start(task="task1_vuln_detection", env=ENV_BENCHMARK, model=MODEL_NAME)
146
 
147
  messages: List[Dict[str, str]] = [
148
- {"role": "system", "content": T1_SYSTEM}
149
  ]
150
  step_rewards: List[float] = []
151
- grader_score = 0.0
152
- steps_taken = 0
153
  error_msg: Optional[str] = None
154
 
155
  try:
156
- for step in range(1, MAX_STEPS_T1 + 1):
157
- messages.append({"role": "user", "content": _t1_user_msg(obs)})
158
  try:
159
- raw = await get_llm_response(messages, max_tokens=200, temperature=0.0)
160
  error_msg = None
161
  except Exception as e:
162
  raw = ""
163
  error_msg = str(e)[:80]
164
- print(f"[DEBUG] T1 LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
165
 
166
  try:
167
  parsed = json.loads(raw)
168
- at = ActionType(parsed["action"])
169
  params = parsed.get("params", {})
170
- except Exception:
171
- at, params = ActionType.LIST_FUNCTIONS, {}
 
172
 
173
  messages.append({"role": "assistant", "content": raw})
174
  result = env.step(Action(action_type=at, params=params))
175
- obs = result.observation.model_dump()
176
- r_val = result.reward.value
177
- done = result.done
178
 
179
  step_rewards.append(r_val)
180
  steps_taken = step
 
181
  log_step(step=step, action=at.value, reward=r_val, done=done, error=error_msg)
182
 
183
  if done:
184
  grader_score = r_val
185
  break
186
 
187
- time.sleep(0.3)
188
 
189
  finally:
190
  success = grader_score >= SUCCESS_SCORE_THRESHOLD
191
  log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)
192
 
193
- return {
194
- "episode": ep_num,
195
- "seed": seed,
196
- "contract": obs["contract_name"],
197
- "grader_score": grader_score
198
  }
 
 
199
 
 
200
 
201
  # ─────────────────────────────────────────────────────────────────────────────
202
- # Task 2 β€” Property Discovery
203
  # ─────────────────────────────────────────────────────────────────────────────
204
 
 
 
 
 
 
205
 
206
- def _t2_user_msg(obs: Dict[str, Any]) -> str:
207
  extra = obs.get("extra", {})
208
  return (
209
  f"Target Function : {extra.get('target_function', '?')} "
210
- # f"({extra.get('target_signature', '')})\n"
211
  f"Last action : {obs['last_action'] or 'None'}\n"
212
  f"Last result :\n{obs['last_action_result'] or 'Episode just started.'}"
213
  )
214
 
 
 
215
 
216
- async def _run_t2_episode(env: Task2Environment, seed: int, ep_num: int) -> Dict[str, Any]:
217
- """Run one Task 2 episode; emit [START]/[STEP]/[END]."""
218
- r = env.reset(seed=seed)
219
- obs = r.observation.model_dump()
220
- fn = obs["extra"].get("target_function", "?")
221
-
222
- log_start(task="task2_property_discovery", env=ENV_BENCHMARK, model=MODEL_NAME)
223
-
224
- messages: List[Dict[str, str]] = [
225
- {"role": "system", "content": T2_SYSTEM}
226
- ]
227
- step_rewards: List[float] = []
228
- grader_score = 0.0
229
- steps_taken = 0
230
- error_msg: Optional[str] = None
231
-
232
- try:
233
- for step in range(1, MAX_STEPS_T2 + 1):
234
- messages.append({"role": "user", "content": _t2_user_msg(obs)})
235
- try:
236
- raw = await get_llm_response(messages, max_tokens=400, temperature=0.0)
237
- error_msg = None
238
- except Exception as e:
239
- raw = ""
240
- error_msg = str(e)[:80]
241
- print(f"[DEBUG] T2 LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
242
-
243
- try:
244
- parsed = json.loads(raw)
245
- at = ActionType(parsed["action"])
246
- params = parsed.get("params", {})
247
- except Exception:
248
- at, params = ActionType.GET_FUNCTION_CODE, {}
249
-
250
- messages.append({"role": "assistant", "content": raw})
251
- result = env.step(Action(action_type=at, params=params))
252
- obs = result.observation.model_dump()
253
- r_val = result.reward.value
254
- done = result.done
255
-
256
- step_rewards.append(r_val)
257
- steps_taken = step
258
- log_step(step=step, action=at.value, reward=r_val, done=done, error=error_msg)
259
-
260
- if done:
261
- grader_score = r_val
262
- break
263
-
264
- time.sleep(0.3)
265
-
266
- finally:
267
- success = grader_score >= SUCCESS_SCORE_THRESHOLD
268
- log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)
269
-
270
- return {
271
- "episode": ep_num,
272
- "seed": seed,
273
- "contract": obs["contract_name"],
274
- "function": fn,
275
- "grader_score": grader_score
276
- }
277
-
278
-
279
- # ─────────────────────────────────────────────────────────────────────────────
280
- # Task 3 β€” Rule Checker
281
- # ─────────────────────────────────────────────────────────────────────────────
282
-
283
-
284
- def _t3_user_msg(obs: Dict[str, Any]) -> str:
285
  extra = obs.get("extra", {})
286
  return (
287
  f"Verify Property : {extra.get('property_english', '(none)')}\n"
@@ -289,139 +242,117 @@ def _t3_user_msg(obs: Dict[str, Any]) -> str:
289
  f"Last result :\n{obs['last_action_result'] or 'Episode just started.'}"
290
  )
291
 
 
 
 
292
 
293
- async def _run_t3_episode(env: Task3Environment, seed: int, ep_num: int) -> Dict[str, Any]:
294
- """Run one Task 3 episode; emit [START]/[STEP]/[END]."""
295
- r = env.reset(seed=seed)
296
- obs = r.observation.model_dump()
297
-
298
- log_start(task="task3_rule_checker", env=ENV_BENCHMARK, model=MODEL_NAME)
299
-
300
- messages: List[Dict[str, str]] = [
301
- {"role": "system", "content": T3_SYSTEM}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  ]
303
- step_rewards: List[float] = []
304
- grader_score = 0.0
305
- steps_taken = 0
306
- error_msg: Optional[str] = None
307
-
308
- try:
309
- for step in range(1, MAX_STEPS_T3 + 1):
310
- messages.append({"role": "user", "content": _t3_user_msg(obs)})
311
- try:
312
- raw = await get_llm_response(messages, max_tokens=200, temperature=0.0)
313
- error_msg = None
314
- except Exception as e:
315
- raw = ""
316
- error_msg = str(e)[:80]
317
- print(f"[DEBUG] T3 LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
318
-
319
- try:
320
- parsed = json.loads(raw)
321
- at = ActionType(parsed["action"])
322
- params = parsed.get("params", {})
323
- except Exception:
324
- at, params = ActionType.LIST_FUNCTIONS, {}
325
-
326
- messages.append({"role": "assistant", "content": raw})
327
- result = env.step(Action(action_type=at, params=params))
328
- obs = result.observation.model_dump()
329
- r_val = result.reward.value
330
- done = result.done
331
-
332
- step_rewards.append(r_val)
333
- steps_taken = step
334
- log_step(step=step, action=at.value, reward=r_val, done=done, error=error_msg)
335
-
336
- if done:
337
- grader_score = r_val
338
- break
339
-
340
- time.sleep(0.3)
341
-
342
- finally:
343
- success = grader_score >= SUCCESS_SCORE_THRESHOLD
344
- log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)
345
 
 
346
  return {
347
- "episode": ep_num,
348
- "seed": seed,
349
- "contract": obs["contract_name"],
350
- "grader_score": grader_score
 
 
351
  }
352
 
353
-
354
  # ─────────────────────────────────────────────────────────────────────────────
355
- # Task runners
356
  # ─────────────────────────────────────────────────────────────────────────────
357
 
358
  async def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
359
- print("\n" + "="*60, flush=True)
360
- print("TASK 1: Targeted Vulnerability Detection", flush=True)
361
- print("="*60, flush=True)
362
- env = Task1Environment()
363
- episodes = [await _run_t1_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
364
- avg_s = sum(e["grader_score"] for e in episodes) / n
365
- print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
366
- return {
367
- "task_id": "task1_vuln_detection", "name": "Targeted Vulnerability Detection",
368
- "status": "active", "num_episodes": n, "episodes": episodes,
369
- "avg_grader_score": avg_s
370
- }
371
-
372
 
373
  async def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
374
- print("\n" + "="*60, flush=True)
375
- print("TASK 2: Property Discovery", flush=True)
376
- print("="*60, flush=True)
377
- env = Task2Environment()
378
- episodes = [await _run_t2_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
379
- avg_s = sum(e["grader_score"] for e in episodes) / n
380
- print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
381
- return {
382
- "task_id": "task2_property_discovery", "name": "Property Discovery",
383
- "status": "active", "num_episodes": n, "episodes": episodes,
384
- "avg_grader_score": avg_s
385
- }
386
-
387
 
388
  async def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
389
- print("\n" + "="*60, flush=True)
390
- print("TASK 3: Rule Checker", flush=True)
391
- print("="*60, flush=True)
392
- env = Task3Environment()
393
- episodes = [await _run_t3_episode(env, SEED_BASE + i, i + 1) for i in range(n)]
394
- avg_s = sum(e["grader_score"] for e in episodes) / n
395
- print(f"\n Avg grader score : {avg_s:.3f}", flush=True)
396
- return {
397
- "task_id": "task3_rule_checker", "name": "Rule Checker",
398
- "status": "active", "num_episodes": n, "episodes": episodes,
399
- "avg_grader_score": avg_s
400
- }
401
-
402
 
403
  # ─────────────────────────────────────────────────────────────────────────────
404
  # Main
405
  # ─────────────────────────────────────────────────────────────────────────────
406
 
407
  async def main() -> None:
408
- """Async entry point (wraps sync env calls; asyncio.run() expected by caller)."""
409
  print("Smart Contract Audit RL Environment β€” Baseline Inference", flush=True)
410
 
411
  t1 = await run_task1(NUM_EPISODES)
412
  t2 = await run_task2(NUM_EPISODES)
413
  t3 = await run_task3(NUM_EPISODES)
414
 
415
- results: Dict[str, Any] = { "tasks": [t1, t2, t3] }
416
  overall = sum(t["avg_grader_score"] for t in results["tasks"]) / 3
417
  results["overall_avg_score"] = overall
418
 
419
- print("\n" + "="*60, flush=True)
420
  print("BASELINE SUMMARY", flush=True)
421
- print("="*60, flush=True)
422
  for t in results["tasks"]:
423
  print(f" βœ… {t['name']:40s}: {t['avg_grader_score']:.3f}", flush=True)
424
- print(f"\n Overall avg grader score: {overall:.3f}", flush=True)
425
 
426
  with open("baseline_scores.json", "w") as f:
427
  json.dump(results, f, indent=2)
 
7
  Emits mandatory structured stdout in the OpenEnv format.
8
 
9
  MANDATORY ENV VARS:
10
+ HF_TOKEN Hugging Face Token (required)
11
  MODEL_NAME Model identifier (default: openai/gpt-oss-20b)
12
 
13
  MANDATORY STDOUT FORMAT (per episode):
 
26
  import json
27
  import os
28
  import sys
29
+ from typing import Any, Dict, List, Optional, Callable, Awaitable, Union
 
30
 
31
  from openai import AsyncOpenAI
32
+ from dotenv import load_dotenv
33
 
34
  from server import Task1Environment, Task2Environment, Task3Environment
35
  from env.schemas import Action, ActionType
36
  from utils import T1_SYSTEM, T2_SYSTEM, T3_SYSTEM
 
37
 
38
  # ─────────────────────────────────────────────────────────────────────────────
39
  # Configuration
 
46
 
47
  if not HF_TOKEN:
48
  raise RuntimeError("HF_TOKEN environment variable not set")
49
+
50
  client = AsyncOpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
51
 
52
+ # from groq import AsyncGroq
53
+ # GROQ_API_KEY = os.getenv("GROQ_API_KEY")
54
+ # client = AsyncGroq(api_key=GROQ_API_KEY)
55
+
56
  # Benchmark / environment identifier (constant for this env)
57
  ENV_BENCHMARK = "smart-contract-audit"
58
 
59
  # Episodes per task
60
+ NUM_EPISODES = 5
61
  SEED_BASE = 42
62
 
63
+ # Max steps per task (same for all tasks)
64
+ MAX_STEPS = 35
 
 
65
 
66
  # A grader_score >= this is considered a "success" for the [END] line
67
  SUCCESS_SCORE_THRESHOLD = 0.5
68
 
69
+ # Throttle concurrent LLM calls
70
+ SEMAPHORE = asyncio.Semaphore(3)
71
+
72
+ # Timeout for each LLM request
73
+ LLM_TIMEOUT = 20
74
+
75
  # ─────────────────────────────────────────────────────────────────────────────
76
  # Unified LLM call function
77
  # ─────────────────────────────────────────────────────────────────────────────
 
86
  Returns the response content as a string.
87
  Raises an exception on failure (to be caught by the caller).
88
  """
89
+ try:
90
+ async with SEMAPHORE:
91
+ completion = await asyncio.wait_for(
92
+ client.chat.completions.create(
93
+ model=MODEL_NAME,
94
+ messages=messages, # type: ignore
95
+ ),
96
+ timeout=LLM_TIMEOUT,
97
+ )
98
+ return completion.choices[0].message.content.strip() # type: ignore
99
+
100
+ except asyncio.TimeoutError:
101
+ raise RuntimeError("LLM request timed out")
102
 
103
  # ─────────────────────────────────────────────────────────────────────────────
104
  # Mandatory stdout helpers
 
109
  print(f"[START] task={task} env={env} model={model}", flush=True)
110
 
111
 
112
+ def log_step( step: int, action: str, reward: float, done: bool, error: Optional[str] = None,
 
 
 
 
 
113
  ) -> None:
114
  """Emit a [STEP] line β€” one per env.step() call."""
115
  error_val = error if error else "null"
 
120
  )
121
 
122
 
123
+ def log_end( success: bool, steps: int, score: float, rewards: List[float]) -> None:
 
 
 
 
 
124
  """Emit the [END] line β€” one per episode, always emitted."""
125
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
126
  print(
 
129
  flush=True,
130
  )
131
 
 
132
  # ─────────────────────────────────────────────────────────────────────────────
133
+ # Generic episode runner
134
  # ─────────────────────────────────────────────────────────────────────────────
135
 
136
async def run_episode(
    env: Union[Task1Environment, Task2Environment, Task3Environment],
    seed: int,
    ep_num: int,
    *,
    task_id: str,
    system_prompt: str,
    user_msg_formatter: Callable[[Dict[str, Any]], str],
    max_tokens: int = 200,
    default_action: ActionType = ActionType.LIST_FUNCTIONS,
    extra_fields: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """
    Run one episode with the given environment and task-specific parameters.

    Emits one [START] line, one [STEP] line per ``env.step()`` call and one
    [END] line on stdout. All diagnostic output goes to stderr so that stdout
    carries only the structured lines.

    Args:
        env: Environment instance; it is reset with ``seed`` before stepping.
        seed: Seed passed to ``env.reset``.
        ep_num: 1-based episode number, used in debug output and the result.
        task_id: Task identifier written to the [START] line and debug output.
        system_prompt: System message that opens the conversation.
        user_msg_formatter: Builds the user message from the observation dict.
        max_tokens: Forwarded to ``get_llm_response``.
        default_action: Action used when the model reply cannot be parsed.
        extra_fields: Optional callable returning extra result entries derived
            from the final observation dict.

    Returns:
        Dict with ``episode``, ``seed``, ``grader_score``, ``contract`` and any
        entries produced by ``extra_fields``.
    """
    r = env.reset(seed=seed)
    obs = r.observation.model_dump()

    log_start(task=task_id, env=ENV_BENCHMARK, model=MODEL_NAME)

    messages: List[Dict[str, str]] = [
        {"role": "system", "content": system_prompt}
    ]
    step_rewards: List[float] = []
    grader_score = 0.0
    steps_taken = 0
    error_msg: Optional[str] = None

    try:
        for step in range(1, MAX_STEPS + 1):
            messages.append({"role": "user", "content": user_msg_formatter(obs)})
            try:
                raw = await get_llm_response(messages, max_tokens=max_tokens, temperature=0.0)
                error_msg = None
            except Exception as e:
                # Best effort: a failed LLM call degrades to the default action.
                raw = ""
                error_msg = str(e)[:80]
                print(f"[DEBUG] {task_id} LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)

            try:
                parsed = json.loads(raw)
                at = ActionType(parsed["action"])
                # A missing or null "params" means "no parameters".
                params = parsed.get("params") or {}
                if not isinstance(params, dict):
                    raise TypeError("'params' must be a JSON object")
            except Exception as e:
                at, params = default_action, {}
                print(
                    f"[DEBUG] {task_id} could not parse LLM response "
                    f"ep={ep_num} step={step}: {e}",
                    file=sys.stderr,
                )

            messages.append({"role": "assistant", "content": raw})
            result = env.step(Action(action_type=at, params=params))
            obs = result.observation.model_dump()
            r_val = result.reward.value
            done = result.done

            step_rewards.append(r_val)
            steps_taken = step
            # Diagnostic echo of the raw reply; kept off stdout on purpose.
            print(f"[DEBUG] {raw} {at.value} {r_val}", file=sys.stderr)
            log_step(step=step, action=at.value, reward=r_val, done=done, error=error_msg)

            if done:
                # The reward of the terminating step is taken as the grader score.
                grader_score = r_val
                break

            await asyncio.sleep(0.3)

    finally:
        # Always emit [END], even when the loop above raised.
        success = grader_score >= SUCCESS_SCORE_THRESHOLD
        log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)

    result_dict = {
        "episode": ep_num,
        "seed": seed,
        "grader_score": grader_score,
        "contract": obs.get("contract_name", ""),
    }
    if extra_fields:
        result_dict.update(extra_fields(obs))

    return result_dict
215
 
216
  # ─────────────────────────────────────────────────────────────────────────────
217
+ # Task-specific user message formatters and extra field extractors
218
  # ─────────────────────────────────────────────────────────────────────────────
219
 
220
def t1_user_msg(obs: Dict[str, Any]) -> str:
    """Build the Task 1 user message from the latest observation dict.

    Falsy ``last_action`` / ``last_action_result`` values are replaced by
    placeholder text.
    """
    previous_action = obs['last_action'] or 'None'
    previous_result = obs['last_action_result'] or 'Episode just started.'
    return f"Last action : {previous_action}\nLast result : {previous_result}"
225
 
226
def t2_user_msg(obs: Dict[str, Any]) -> str:
    """Build the Task 2 user message from the latest observation dict.

    The target function is read from ``obs["extra"]["target_function"]`` and
    shown as ``?`` when absent. Falsy ``last_action`` / ``last_action_result``
    values are replaced by placeholder text.
    """
    extra = obs.get("extra", {})
    return (
        # Each field sits on its own line; the first line previously ended with
        # a space, which glued "Last action" onto the target-function line.
        f"Target Function : {extra.get('target_function', '?')}\n"
        f"Last action : {obs['last_action'] or 'None'}\n"
        f"Last result :\n{obs['last_action_result'] or 'Episode just started.'}"
    )
233
 
234
def t2_extra_fields(obs: Dict[str, Any]) -> Dict[str, Any]:
    """Return the extra Task 2 result entry: the target function name, or ``?``."""
    extra = obs.get("extra", {})
    target = extra.get("target_function", "?")
    return {"function": target}
236
 
237
+ def t3_user_msg(obs: Dict[str, Any]) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  extra = obs.get("extra", {})
239
  return (
240
  f"Verify Property : {extra.get('property_english', '(none)')}\n"
 
242
  f"Last result :\n{obs['last_action_result'] or 'Episode just started.'}"
243
  )
244
 
245
+ # ─────────────────────────────────────────────────────────────────────────────
246
+ # Generic task runner
247
+ # ─────────────────────────────────────────────────────────────────────────────
248
 
249
async def run_task(
    task_id: str,
    task_name: str,
    env_class: type,
    system_prompt: str,
    user_msg_formatter: Callable[[Dict[str, Any]], str],
    max_tokens: int = 200,
    default_action: ActionType = ActionType.LIST_FUNCTIONS,
    extra_fields: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    num_episodes: int = NUM_EPISODES,
) -> Dict[str, Any]:
    """Run multiple episodes for a given task and return aggregated results.

    Episodes are started together with ``asyncio.gather`` and seeded with
    ``SEED_BASE + i``.

    Args:
        task_id: Task identifier passed to every episode and to the result.
        task_name: Human-readable name printed in the banner and the result.
        env_class: Zero-argument callable producing an environment instance.
        system_prompt: System prompt forwarded to ``run_episode``.
        user_msg_formatter: Observation-to-user-message formatter.
        max_tokens: Forwarded to ``run_episode``.
        default_action: Fallback action forwarded to ``run_episode``.
        extra_fields: Optional extra-result extractor forwarded to ``run_episode``.
        num_episodes: Number of episodes to run.

    Returns:
        Dict with ``task_id``, ``name``, ``status``, ``num_episodes``,
        ``episodes`` (per-episode result dicts) and ``avg_grader_score``.
    """
    print("\n" + "=" * 60, flush=True)
    print(f"TASK: {task_name}", flush=True)
    print("=" * 60, flush=True)

    # One environment per episode: the episodes run concurrently and each one
    # awaits the LLM between env.step() calls, so a shared instance would have
    # its state reset and stepped by several episodes at once.
    # NOTE(review): assumes env_class() is cheap to construct and instances do
    # not share mutable global state - confirm against the environment classes.
    # NOTE(review): [STEP] lines of concurrent episodes interleave on stdout.
    tasks = [
        run_episode(
            env_class(),
            seed=SEED_BASE + i,
            ep_num=i + 1,
            task_id=task_id,
            system_prompt=system_prompt,
            user_msg_formatter=user_msg_formatter,
            max_tokens=max_tokens,
            default_action=default_action,
            extra_fields=extra_fields,
        )
        for i in range(num_episodes)
    ]
    episodes = await asyncio.gather(*tasks)
    # Avoid ZeroDivisionError when no episodes were requested.
    avg_score = (
        sum(e["grader_score"] for e in episodes) / len(episodes) if episodes else 0.0
    )

    print(f"\n Avg grader score : {avg_score:.4f}", flush=True)
    return {
        "task_id": task_id,
        "name": task_name,
        "status": "active",
        "num_episodes": num_episodes,
        "episodes": episodes,
        "avg_grader_score": avg_score,
    }
292
 
 
293
  # ─────────────────────────────────────────────────────────────────────────────
294
+ # Task-specific runners (thin wrappers for clarity)
295
  # ─────────────────────────────────────────────────────────────────────────────
296
 
297
  async def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
298
+ return await run_task(
299
+ task_id="task1_vuln_detection",
300
+ task_name="Targeted Vulnerability Detection",
301
+ env_class=Task1Environment,
302
+ system_prompt=T1_SYSTEM,
303
+ user_msg_formatter=t1_user_msg,
304
+ max_tokens=200,
305
+ default_action=ActionType.LIST_FUNCTIONS,
306
+ num_episodes=n,
307
+ )
 
 
 
308
 
309
  async def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
310
+ return await run_task(
311
+ task_id="task2_property_discovery",
312
+ task_name="Property Discovery",
313
+ env_class=Task2Environment,
314
+ system_prompt=T2_SYSTEM,
315
+ user_msg_formatter=t2_user_msg,
316
+ max_tokens=400,
317
+ default_action=ActionType.GET_FUNCTION_CODE,
318
+ extra_fields=t2_extra_fields,
319
+ num_episodes=n,
320
+ )
 
 
321
 
322
  async def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
323
+ return await run_task(
324
+ task_id="task3_rule_checker",
325
+ task_name="Rule Checker",
326
+ env_class=Task3Environment,
327
+ system_prompt=T3_SYSTEM,
328
+ user_msg_formatter=t3_user_msg,
329
+ max_tokens=200,
330
+ default_action=ActionType.LIST_FUNCTIONS,
331
+ num_episodes=n,
332
+ )
 
 
 
333
 
334
  # ─────────────────────────────────────────────────────────────────────────────
335
  # Main
336
  # ─────────────────────────────────────────────────────────────────────────────
337
 
338
  async def main() -> None:
339
+ """Async entry point."""
340
  print("Smart Contract Audit RL Environment β€” Baseline Inference", flush=True)
341
 
342
  t1 = await run_task1(NUM_EPISODES)
343
  t2 = await run_task2(NUM_EPISODES)
344
  t3 = await run_task3(NUM_EPISODES)
345
 
346
+ results: Dict[str, Any] = {"tasks": [t1, t2, t3]}
347
  overall = sum(t["avg_grader_score"] for t in results["tasks"]) / 3
348
  results["overall_avg_score"] = overall
349
 
350
+ print("\n" + "=" * 60, flush=True)
351
  print("BASELINE SUMMARY", flush=True)
352
+ print("=" * 60, flush=True)
353
  for t in results["tasks"]:
354
  print(f" βœ… {t['name']:40s}: {t['avg_grader_score']:.3f}", flush=True)
355
+ print(f"\n Overall avg grader score: {overall:.4f}", flush=True)
356
 
357
  with open("baseline_scores.json", "w") as f:
358
  json.dump(results, f, indent=2)
utils/prompts.py CHANGED
@@ -32,7 +32,8 @@ Common vulnerabilities in contracts:
32
  - denial of service
33
 
34
  Submit immediately once confident.
35
- Output: JSON only. No text.
 
36
  """
37
 
38
  T2_SYSTEM = """You are a Solidity formal methods engineer.
@@ -71,7 +72,8 @@ Format:
71
 
72
  Submit immediately once confident.
73
 
74
- Output: JSON only.
 
75
  """
76
 
77
  T3_SYSTEM = """You are a Solidity security auditor.
@@ -108,5 +110,6 @@ Example Violation heuristics:
108
  Select the function that clearly breaks the property.
109
  Submit immediately once confident.
110
 
111
- Output: JSON only.
 
112
  """
 
32
  - denial of service
33
 
34
  Submit immediately once confident.
35
+ Output: JSON only. No text. FOLLOW EXACT STRCUTURE OF ACTIONS GIVEN ANY CHANGE WILL LEAD TO
36
+ INVALID ACTION. It's case-sensitive as well.
37
  """
38
 
39
  T2_SYSTEM = """You are a Solidity formal methods engineer.
 
72
 
73
  Submit immediately once confident.
74
 
75
+ Output: JSON only. No text. FOLLOW EXACT STRCUTURE OF ACTIONS GIVEN ANY CHANGE WILL LEAD TO
76
+ INVALID ACTION. It's case-sensitive as well.
77
  """
78
 
79
  T3_SYSTEM = """You are a Solidity security auditor.
 
110
  Select the function that clearly breaks the property.
111
  Submit immediately once confident.
112
 
113
+ Output: JSON only. No text. FOLLOW EXACT STRCUTURE OF ACTIONS GIVEN ANY CHANGE WILL LEAD TO
114
+ INVALID ACTION. It's case-sensitive as well.
115
  """