Adhitya-Vardhan committed on
Commit
f00a888
·
1 Parent(s): 61b9118

fix: structured [START]/[STEP]/[END] output format for validator

Browse files

- [START] task=NAME
- [STEP] step=N action=ACTION reward=R (with flush=True)
- [END] task=NAME score=S steps=N
All 4 tasks score 1.0 with heuristic policy.

Files changed (1) hide show
  1. inference.py +12 -12
inference.py CHANGED
@@ -190,8 +190,7 @@ def sanitize_action_payload(payload: Dict) -> Dict:
190
 
191
 
192
  def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, float]:
193
- print(f"START")
194
- print(f"Task: {task_id}")
195
  env = VulnTriageEnvironment()
196
  observation = env.reset(task_id=task_id).model_dump()
197
  client = get_openai_client() if policy == "openai" else None
@@ -200,7 +199,6 @@ def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, f
200
  step_num: int = 1
201
 
202
  while not observation["done"]:
203
- print(f"STEP")
204
  action_payload = (
205
  llm_policy(client, model_name, observation) if client else heuristic_policy(observation)
206
  )
@@ -209,7 +207,7 @@ def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, f
209
  clean = sanitize_action_payload(action_payload)
210
  action = VulnTriageAction.model_validate(clean)
211
  except Exception as exc:
212
- print(f" [warn] invalid action payload ({exc}), falling back to read_report")
213
  action = VulnTriageAction(action_type="read_report", rationale="fallback: parse error")
214
 
215
  # Break infinite loops where model repeats the same action
@@ -217,17 +215,19 @@ def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, f
217
  if action_str == last_action_str:
218
  repeat_count += 1
219
  if repeat_count >= 3:
220
- print(f" [warn] model repeated same action 3x — forcing submit_triage")
221
  action = VulnTriageAction(action_type="submit_triage", rationale="loop guard")
222
  else:
223
  repeat_count = 0
224
  last_action_str = action_str
225
 
226
- print(f"Action: {action.action_type}")
227
  observation = env.step(action).model_dump()
 
 
228
  step_num += 1
229
 
230
- print(f"END")
 
231
 
232
  return {
233
  "task_id": task_id,
@@ -248,8 +248,7 @@ def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, f
248
 
249
 
250
  def run_remote_episode(base_url: str, task_id: str, policy: str, model_name: str) -> Dict[str, float]:
251
- print(f"START")
252
- print(f"Task: {task_id}")
253
  llm_client = get_openai_client() if policy == "openai" else None
254
  env = GenericEnvClient(base_url=base_url).sync()
255
  with env:
@@ -258,19 +257,20 @@ def run_remote_episode(base_url: str, task_id: str, policy: str, model_name: str
258
  done = response.done
259
  step_num: int = 1
260
  while not done:
261
- print(f"STEP")
262
  action_payload = (
263
  llm_policy(llm_client, model_name, observation)
264
  if llm_client
265
  else heuristic_policy(observation)
266
  )
267
- print(f"Action: {action_payload.get('action_type')}")
268
  response = env.step(action_payload)
269
  observation = response.observation
270
  done = response.done
 
 
271
  step_num += 1
272
 
273
- print(f"END")
 
274
 
275
  final_score = float(observation.get("final_score") or 0.0)
276
  return {
 
190
 
191
 
192
  def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, float]:
193
+ print(f"[START] task={task_id}", flush=True)
 
194
  env = VulnTriageEnvironment()
195
  observation = env.reset(task_id=task_id).model_dump()
196
  client = get_openai_client() if policy == "openai" else None
 
199
  step_num: int = 1
200
 
201
  while not observation["done"]:
 
202
  action_payload = (
203
  llm_policy(client, model_name, observation) if client else heuristic_policy(observation)
204
  )
 
207
  clean = sanitize_action_payload(action_payload)
208
  action = VulnTriageAction.model_validate(clean)
209
  except Exception as exc:
210
+ print(f" [warn] invalid action payload ({exc}), falling back to read_report", flush=True)
211
  action = VulnTriageAction(action_type="read_report", rationale="fallback: parse error")
212
 
213
  # Break infinite loops where model repeats the same action
 
215
  if action_str == last_action_str:
216
  repeat_count += 1
217
  if repeat_count >= 3:
218
+ print(f" [warn] model repeated same action 3x — forcing submit_triage", flush=True)
219
  action = VulnTriageAction(action_type="submit_triage", rationale="loop guard")
220
  else:
221
  repeat_count = 0
222
  last_action_str = action_str
223
 
 
224
  observation = env.step(action).model_dump()
225
+ step_reward = float(observation.get("reward") or 0.0)
226
+ print(f"[STEP] step={step_num} action={action.action_type} reward={step_reward}", flush=True)
227
  step_num += 1
228
 
229
+ final_score = float(observation.get("final_score") or observation.get("score_breakdown", {}).get("total", 0.0))
230
+ print(f"[END] task={task_id} score={final_score} steps={step_num}", flush=True)
231
 
232
  return {
233
  "task_id": task_id,
 
248
 
249
 
250
  def run_remote_episode(base_url: str, task_id: str, policy: str, model_name: str) -> Dict[str, float]:
251
+ print(f"[START] task={task_id}", flush=True)
 
252
  llm_client = get_openai_client() if policy == "openai" else None
253
  env = GenericEnvClient(base_url=base_url).sync()
254
  with env:
 
257
  done = response.done
258
  step_num: int = 1
259
  while not done:
 
260
  action_payload = (
261
  llm_policy(llm_client, model_name, observation)
262
  if llm_client
263
  else heuristic_policy(observation)
264
  )
 
265
  response = env.step(action_payload)
266
  observation = response.observation
267
  done = response.done
268
+ step_reward = float(getattr(response, 'reward', None) or 0.0)
269
+ print(f"[STEP] step={step_num} action={action_payload.get('action_type')} reward={step_reward}", flush=True)
270
  step_num += 1
271
 
272
+ final_score = float(observation.get("final_score") or observation.get("score_breakdown", {}).get("total", 0.0))
273
+ print(f"[END] task={task_id} score={final_score} steps={step_num}", flush=True)
274
 
275
  final_score = float(observation.get("final_score") or 0.0)
276
  return {