fix: update grader scores to fall strictly within (0, 1)
Browse files
- inference.py +6 -6
- server/environment.py +8 -8
- server/tasks.py +3 -3
inference.py
CHANGED
|
@@ -191,7 +191,7 @@ def run_task(task: dict):
|
|
| 191 |
obs = reset_data.get("observation", {})
|
| 192 |
info = obs.get("info", "")
|
| 193 |
except Exception as e:
|
| 194 |
-
log_end(success=False, total_steps=0, score=0.
|
| 195 |
return
|
| 196 |
|
| 197 |
conversation = [
|
|
@@ -209,7 +209,7 @@ def run_task(task: dict):
|
|
| 209 |
action = ask_llm(system_prompt, conversation)
|
| 210 |
except Exception as e:
|
| 211 |
last_error = f"LLM error: {str(e)}"
|
| 212 |
-
log_step(step_num, {"error": "LLM failed"}, 0.
|
| 213 |
break
|
| 214 |
|
| 215 |
# Execute the action in the environment
|
|
@@ -221,14 +221,14 @@ def run_task(task: dict):
|
|
| 221 |
last_error = obs.get("last_action_error")
|
| 222 |
except Exception as e:
|
| 223 |
last_error = f"Env error: {str(e)}"
|
| 224 |
-
log_step(step_num, action, 0.
|
| 225 |
break
|
| 226 |
|
| 227 |
rewards.append(reward)
|
| 228 |
log_step(step_num, action, reward, done, error=last_error)
|
| 229 |
|
| 230 |
if done:
|
| 231 |
-
success = (reward >=
|
| 232 |
break
|
| 233 |
|
| 234 |
# Build observation summary for the LLM
|
|
@@ -250,8 +250,8 @@ def run_task(task: dict):
|
|
| 250 |
conversation.append({"role": "assistant", "content": json.dumps(action)})
|
| 251 |
conversation.append({"role": "user", "content": f"Observation from environment:\n{obs_text}\n\nDecide your next action."})
|
| 252 |
|
| 253 |
-
# Calculate final score (normalized to
|
| 254 |
-
final_score = max(0.
|
| 255 |
|
| 256 |
log_end(success=success, total_steps=step_num, score=final_score, rewards=rewards)
|
| 257 |
|
|
|
|
| 191 |
obs = reset_data.get("observation", {})
|
| 192 |
info = obs.get("info", "")
|
| 193 |
except Exception as e:
|
| 194 |
+
log_end(success=False, total_steps=0, score=0.01, rewards=[])
|
| 195 |
return
|
| 196 |
|
| 197 |
conversation = [
|
|
|
|
| 209 |
action = ask_llm(system_prompt, conversation)
|
| 210 |
except Exception as e:
|
| 211 |
last_error = f"LLM error: {str(e)}"
|
| 212 |
+
log_step(step_num, {"error": "LLM failed"}, 0.01, True, error=last_error)
|
| 213 |
break
|
| 214 |
|
| 215 |
# Execute the action in the environment
|
|
|
|
| 221 |
last_error = obs.get("last_action_error")
|
| 222 |
except Exception as e:
|
| 223 |
last_error = f"Env error: {str(e)}"
|
| 224 |
+
log_step(step_num, action, 0.01, True, error=last_error)
|
| 225 |
break
|
| 226 |
|
| 227 |
rewards.append(reward)
|
| 228 |
log_step(step_num, action, reward, done, error=last_error)
|
| 229 |
|
| 230 |
if done:
|
| 231 |
+
success = (reward >= 0.8) # Assume 0.8+ is full success (max is 0.85)
|
| 232 |
break
|
| 233 |
|
| 234 |
# Build observation summary for the LLM
|
|
|
|
| 250 |
conversation.append({"role": "assistant", "content": json.dumps(action)})
|
| 251 |
conversation.append({"role": "user", "content": f"Observation from environment:\n{obs_text}\n\nDecide your next action."})
|
| 252 |
|
| 253 |
+
# Calculate final score (normalized to (0, 1) to satisfy validator)
|
| 254 |
+
final_score = max(0.01, min(0.99, sum(rewards)))
|
| 255 |
|
| 256 |
log_end(success=success, total_steps=step_num, score=final_score, rewards=rewards)
|
| 257 |
|
server/environment.py
CHANGED
|
@@ -12,7 +12,7 @@ class CloudAuditEnv:
|
|
| 12 |
self.episode_id = str(uuid.uuid4())
|
| 13 |
self.step_count = 0
|
| 14 |
self.is_completed = False
|
| 15 |
-
self.score = 0.
|
| 16 |
|
| 17 |
# Mock Infrastructure
|
| 18 |
self.resources = {
|
|
@@ -40,13 +40,13 @@ class CloudAuditEnv:
|
|
| 40 |
"""Required by openenv-core 0.1.1: takes task_id, returns JUST the observation."""
|
| 41 |
self.task_id = task_id
|
| 42 |
self._initialize_state()
|
| 43 |
-
return CloudObservation(info=f"Environment reset. Task: {self.task_id}", reward=0.
|
| 44 |
|
| 45 |
def step(self, action: CloudAction) -> CloudObservation:
|
| 46 |
"""Required by openenv-core 0.1.1: takes action, returns JUST the observation with reward/done fields."""
|
| 47 |
try:
|
| 48 |
self.step_count += 1
|
| 49 |
-
reward = 0.
|
| 50 |
terminated = False
|
| 51 |
truncated = self.step_count >= 20 # Limit steps
|
| 52 |
|
|
@@ -86,7 +86,7 @@ class CloudAuditEnv:
|
|
| 86 |
rules = self.resources["ec2"][0]["security_groups"][0]["rules"]
|
| 87 |
has_rdp = any(r["port"] == 3389 and r["cidr"] == "0.0.0.0/0" for r in rules)
|
| 88 |
if not has_rdp:
|
| 89 |
-
reward =
|
| 90 |
terminated = True
|
| 91 |
obs.info = "Success! Port 3389 removed. Task completed."
|
| 92 |
else:
|
|
@@ -112,7 +112,7 @@ class CloudAuditEnv:
|
|
| 112 |
answers = [a.strip() for a in action.answer.split(",")]
|
| 113 |
expected = ["prod-data-001"]
|
| 114 |
if set(answers) == set(expected):
|
| 115 |
-
reward =
|
| 116 |
terminated = True
|
| 117 |
obs.info = "Correct! Task completed."
|
| 118 |
else:
|
|
@@ -121,7 +121,7 @@ class CloudAuditEnv:
|
|
| 121 |
elif self.task_id == "hard":
|
| 122 |
# Expecting rogue IP from auth-logs
|
| 123 |
if action.answer and action.answer.strip() == "192.168.1.50":
|
| 124 |
-
reward =
|
| 125 |
terminated = True
|
| 126 |
obs.info = "Correct! Rogue IP identified. Task completed."
|
| 127 |
else:
|
|
@@ -130,7 +130,7 @@ class CloudAuditEnv:
|
|
| 130 |
elif self.task_id == "medium":
|
| 131 |
obs.info = "For the medium task, use the 'modify' action to update the EC2 security group, not 'submit'."
|
| 132 |
|
| 133 |
-
self.score
|
| 134 |
obs.reward = reward
|
| 135 |
obs.done = terminated or truncated
|
| 136 |
return obs
|
|
@@ -139,7 +139,7 @@ class CloudAuditEnv:
|
|
| 139 |
import traceback
|
| 140 |
print(f"ERROR in environment.step: {str(e)}", file=sys.stderr)
|
| 141 |
traceback.print_exc(file=sys.stderr)
|
| 142 |
-
return CloudObservation(status=f"Internal Server Error: {str(e)}", reward=0.
|
| 143 |
|
| 144 |
def state(self) -> CloudState:
|
| 145 |
return CloudState(
|
|
|
|
| 12 |
self.episode_id = str(uuid.uuid4())
|
| 13 |
self.step_count = 0
|
| 14 |
self.is_completed = False
|
| 15 |
+
self.score = 0.01
|
| 16 |
|
| 17 |
# Mock Infrastructure
|
| 18 |
self.resources = {
|
|
|
|
| 40 |
"""Required by openenv-core 0.1.1: takes task_id, returns JUST the observation."""
|
| 41 |
self.task_id = task_id
|
| 42 |
self._initialize_state()
|
| 43 |
+
return CloudObservation(info=f"Environment reset. Task: {self.task_id}", reward=0.01, done=False)
|
| 44 |
|
| 45 |
def step(self, action: CloudAction) -> CloudObservation:
|
| 46 |
"""Required by openenv-core 0.1.1: takes action, returns JUST the observation with reward/done fields."""
|
| 47 |
try:
|
| 48 |
self.step_count += 1
|
| 49 |
+
reward = 0.005
|
| 50 |
terminated = False
|
| 51 |
truncated = self.step_count >= 20 # Limit steps
|
| 52 |
|
|
|
|
| 86 |
rules = self.resources["ec2"][0]["security_groups"][0]["rules"]
|
| 87 |
has_rdp = any(r["port"] == 3389 and r["cidr"] == "0.0.0.0/0" for r in rules)
|
| 88 |
if not has_rdp:
|
| 89 |
+
reward = 0.85
|
| 90 |
terminated = True
|
| 91 |
obs.info = "Success! Port 3389 removed. Task completed."
|
| 92 |
else:
|
|
|
|
| 112 |
answers = [a.strip() for a in action.answer.split(",")]
|
| 113 |
expected = ["prod-data-001"]
|
| 114 |
if set(answers) == set(expected):
|
| 115 |
+
reward = 0.85
|
| 116 |
terminated = True
|
| 117 |
obs.info = "Correct! Task completed."
|
| 118 |
else:
|
|
|
|
| 121 |
elif self.task_id == "hard":
|
| 122 |
# Expecting rogue IP from auth-logs
|
| 123 |
if action.answer and action.answer.strip() == "192.168.1.50":
|
| 124 |
+
reward = 0.85
|
| 125 |
terminated = True
|
| 126 |
obs.info = "Correct! Rogue IP identified. Task completed."
|
| 127 |
else:
|
|
|
|
| 130 |
elif self.task_id == "medium":
|
| 131 |
obs.info = "For the medium task, use the 'modify' action to update the EC2 security group, not 'submit'."
|
| 132 |
|
| 133 |
+
self.score = min(0.99, self.score + reward)
|
| 134 |
obs.reward = reward
|
| 135 |
obs.done = terminated or truncated
|
| 136 |
return obs
|
|
|
|
| 139 |
import traceback
|
| 140 |
print(f"ERROR in environment.step: {str(e)}", file=sys.stderr)
|
| 141 |
traceback.print_exc(file=sys.stderr)
|
| 142 |
+
return CloudObservation(status=f"Internal Server Error: {str(e)}", reward=0.01, done=True)
|
| 143 |
|
| 144 |
def state(self) -> CloudState:
|
| 145 |
return CloudState(
|
server/tasks.py
CHANGED
|
@@ -5,19 +5,19 @@ TASKS = {
|
|
| 5 |
"id": "easy",
|
| 6 |
"description": "Identify all public S3 buckets in the 'prod' region. Submit their IDs as a comma-separated string.",
|
| 7 |
"difficulty": "Easy",
|
| 8 |
-
"reward_threshold":
|
| 9 |
},
|
| 10 |
"medium": {
|
| 11 |
"id": "medium",
|
| 12 |
"description": "Find the EC2 instance with ID 'i-0abcdef1234567890' in the 'dev' environment. Close port 3389 (RDP) by updating its security group rules to only allow port 22.",
|
| 13 |
"difficulty": "Medium",
|
| 14 |
-
"reward_threshold":
|
| 15 |
},
|
| 16 |
"hard": {
|
| 17 |
"id": "hard",
|
| 18 |
"description": "A rogue IAM role 'iam-role-01' has been performing unauthorized actions. Fetch the 'auth-logs' and identify the IP address that performed 'DeleteStorage'. Submit the IP address.",
|
| 19 |
"difficulty": "Hard",
|
| 20 |
-
"reward_threshold":
|
| 21 |
}
|
| 22 |
}
|
| 23 |
|
|
|
|
| 5 |
"id": "easy",
|
| 6 |
"description": "Identify all public S3 buckets in the 'prod' region. Submit their IDs as a comma-separated string.",
|
| 7 |
"difficulty": "Easy",
|
| 8 |
+
"reward_threshold": 0.8,
|
| 9 |
},
|
| 10 |
"medium": {
|
| 11 |
"id": "medium",
|
| 12 |
"description": "Find the EC2 instance with ID 'i-0abcdef1234567890' in the 'dev' environment. Close port 3389 (RDP) by updating its security group rules to only allow port 22.",
|
| 13 |
"difficulty": "Medium",
|
| 14 |
+
"reward_threshold": 0.8,
|
| 15 |
},
|
| 16 |
"hard": {
|
| 17 |
"id": "hard",
|
| 18 |
"description": "A rogue IAM role 'iam-role-01' has been performing unauthorized actions. Fetch the 'auth-logs' and identify the IP address that performed 'DeleteStorage'. Submit the IP address.",
|
| 19 |
"difficulty": "Hard",
|
| 20 |
+
"reward_threshold": 0.8,
|
| 21 |
}
|
| 22 |
}
|
| 23 |
|