Spaces:
Sleeping
Sleeping
Adhitya-Vardhan committed on
Commit ·
f00a888
1
Parent(s): 61b9118
fix: structured [START]/[STEP]/[END] output format for validator
Browse files
- [START] task=NAME
- [STEP] step=N action=ACTION reward=R (with flush=True)
- [END] task=NAME score=S steps=N
All 4 tasks score 1.0 with heuristic policy.
- inference.py +12 -12
inference.py
CHANGED
|
@@ -190,8 +190,7 @@ def sanitize_action_payload(payload: Dict) -> Dict:
|
|
| 190 |
|
| 191 |
|
| 192 |
def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, float]:
|
| 193 |
-
print(f"START")
|
| 194 |
-
print(f"Task: {task_id}")
|
| 195 |
env = VulnTriageEnvironment()
|
| 196 |
observation = env.reset(task_id=task_id).model_dump()
|
| 197 |
client = get_openai_client() if policy == "openai" else None
|
|
@@ -200,7 +199,6 @@ def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, f
|
|
| 200 |
step_num: int = 1
|
| 201 |
|
| 202 |
while not observation["done"]:
|
| 203 |
-
print(f"STEP")
|
| 204 |
action_payload = (
|
| 205 |
llm_policy(client, model_name, observation) if client else heuristic_policy(observation)
|
| 206 |
)
|
|
@@ -209,7 +207,7 @@ def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, f
|
|
| 209 |
clean = sanitize_action_payload(action_payload)
|
| 210 |
action = VulnTriageAction.model_validate(clean)
|
| 211 |
except Exception as exc:
|
| 212 |
-
print(f" [warn] invalid action payload ({exc}), falling back to read_report")
|
| 213 |
action = VulnTriageAction(action_type="read_report", rationale="fallback: parse error")
|
| 214 |
|
| 215 |
# Break infinite loops where model repeats the same action
|
|
@@ -217,17 +215,19 @@ def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, f
|
|
| 217 |
if action_str == last_action_str:
|
| 218 |
repeat_count += 1
|
| 219 |
if repeat_count >= 3:
|
| 220 |
-
print(f" [warn] model repeated same action 3x — forcing submit_triage")
|
| 221 |
action = VulnTriageAction(action_type="submit_triage", rationale="loop guard")
|
| 222 |
else:
|
| 223 |
repeat_count = 0
|
| 224 |
last_action_str = action_str
|
| 225 |
|
| 226 |
-
print(f"Action: {action.action_type}")
|
| 227 |
observation = env.step(action).model_dump()
|
|
|
|
|
|
|
| 228 |
step_num += 1
|
| 229 |
|
| 230 |
-
|
|
|
|
| 231 |
|
| 232 |
return {
|
| 233 |
"task_id": task_id,
|
|
@@ -248,8 +248,7 @@ def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, f
|
|
| 248 |
|
| 249 |
|
| 250 |
def run_remote_episode(base_url: str, task_id: str, policy: str, model_name: str) -> Dict[str, float]:
|
| 251 |
-
print(f"START")
|
| 252 |
-
print(f"Task: {task_id}")
|
| 253 |
llm_client = get_openai_client() if policy == "openai" else None
|
| 254 |
env = GenericEnvClient(base_url=base_url).sync()
|
| 255 |
with env:
|
|
@@ -258,19 +257,20 @@ def run_remote_episode(base_url: str, task_id: str, policy: str, model_name: str
|
|
| 258 |
done = response.done
|
| 259 |
step_num: int = 1
|
| 260 |
while not done:
|
| 261 |
-
print(f"STEP")
|
| 262 |
action_payload = (
|
| 263 |
llm_policy(llm_client, model_name, observation)
|
| 264 |
if llm_client
|
| 265 |
else heuristic_policy(observation)
|
| 266 |
)
|
| 267 |
-
print(f"Action: {action_payload.get('action_type')}")
|
| 268 |
response = env.step(action_payload)
|
| 269 |
observation = response.observation
|
| 270 |
done = response.done
|
|
|
|
|
|
|
| 271 |
step_num += 1
|
| 272 |
|
| 273 |
-
|
|
|
|
| 274 |
|
| 275 |
final_score = float(observation.get("final_score") or 0.0)
|
| 276 |
return {
|
|
|
|
| 190 |
|
| 191 |
|
| 192 |
def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, float]:
|
| 193 |
+
print(f"[START] task={task_id}", flush=True)
|
|
|
|
| 194 |
env = VulnTriageEnvironment()
|
| 195 |
observation = env.reset(task_id=task_id).model_dump()
|
| 196 |
client = get_openai_client() if policy == "openai" else None
|
|
|
|
| 199 |
step_num: int = 1
|
| 200 |
|
| 201 |
while not observation["done"]:
|
|
|
|
| 202 |
action_payload = (
|
| 203 |
llm_policy(client, model_name, observation) if client else heuristic_policy(observation)
|
| 204 |
)
|
|
|
|
| 207 |
clean = sanitize_action_payload(action_payload)
|
| 208 |
action = VulnTriageAction.model_validate(clean)
|
| 209 |
except Exception as exc:
|
| 210 |
+
print(f" [warn] invalid action payload ({exc}), falling back to read_report", flush=True)
|
| 211 |
action = VulnTriageAction(action_type="read_report", rationale="fallback: parse error")
|
| 212 |
|
| 213 |
# Break infinite loops where model repeats the same action
|
|
|
|
| 215 |
if action_str == last_action_str:
|
| 216 |
repeat_count += 1
|
| 217 |
if repeat_count >= 3:
|
| 218 |
+
print(f" [warn] model repeated same action 3x — forcing submit_triage", flush=True)
|
| 219 |
action = VulnTriageAction(action_type="submit_triage", rationale="loop guard")
|
| 220 |
else:
|
| 221 |
repeat_count = 0
|
| 222 |
last_action_str = action_str
|
| 223 |
|
|
|
|
| 224 |
observation = env.step(action).model_dump()
|
| 225 |
+
step_reward = float(observation.get("reward") or 0.0)
|
| 226 |
+
print(f"[STEP] step={step_num} action={action.action_type} reward={step_reward}", flush=True)
|
| 227 |
step_num += 1
|
| 228 |
|
| 229 |
+
final_score = float(observation.get("final_score") or observation.get("score_breakdown", {}).get("total", 0.0))
|
| 230 |
+
print(f"[END] task={task_id} score={final_score} steps={step_num}", flush=True)
|
| 231 |
|
| 232 |
return {
|
| 233 |
"task_id": task_id,
|
|
|
|
| 248 |
|
| 249 |
|
| 250 |
def run_remote_episode(base_url: str, task_id: str, policy: str, model_name: str) -> Dict[str, float]:
|
| 251 |
+
print(f"[START] task={task_id}", flush=True)
|
|
|
|
| 252 |
llm_client = get_openai_client() if policy == "openai" else None
|
| 253 |
env = GenericEnvClient(base_url=base_url).sync()
|
| 254 |
with env:
|
|
|
|
| 257 |
done = response.done
|
| 258 |
step_num: int = 1
|
| 259 |
while not done:
|
|
|
|
| 260 |
action_payload = (
|
| 261 |
llm_policy(llm_client, model_name, observation)
|
| 262 |
if llm_client
|
| 263 |
else heuristic_policy(observation)
|
| 264 |
)
|
|
|
|
| 265 |
response = env.step(action_payload)
|
| 266 |
observation = response.observation
|
| 267 |
done = response.done
|
| 268 |
+
step_reward = float(getattr(response, 'reward', None) or 0.0)
|
| 269 |
+
print(f"[STEP] step={step_num} action={action_payload.get('action_type')} reward={step_reward}", flush=True)
|
| 270 |
step_num += 1
|
| 271 |
|
| 272 |
+
final_score = float(observation.get("final_score") or observation.get("score_breakdown", {}).get("total", 0.0))
|
| 273 |
+
print(f"[END] task={task_id} score={final_score} steps={step_num}", flush=True)
|
| 274 |
|
| 275 |
final_score = float(observation.get("final_score") or 0.0)
|
| 276 |
return {
|