iitian commited on
Commit
47ab3b8
·
1 Parent(s): 7b50b8a

fix: update grader scores to fall strictly within (0, 1)

Browse files
Files changed (3) hide show
  1. inference.py +6 -6
  2. server/environment.py +8 -8
  3. server/tasks.py +3 -3
inference.py CHANGED
@@ -191,7 +191,7 @@ def run_task(task: dict):
191
  obs = reset_data.get("observation", {})
192
  info = obs.get("info", "")
193
  except Exception as e:
194
- log_end(success=False, total_steps=0, score=0.0, rewards=[])
195
  return
196
 
197
  conversation = [
@@ -209,7 +209,7 @@ def run_task(task: dict):
209
  action = ask_llm(system_prompt, conversation)
210
  except Exception as e:
211
  last_error = f"LLM error: {str(e)}"
212
- log_step(step_num, {"error": "LLM failed"}, 0.0, True, error=last_error)
213
  break
214
 
215
  # Execute the action in the environment
@@ -221,14 +221,14 @@ def run_task(task: dict):
221
  last_error = obs.get("last_action_error")
222
  except Exception as e:
223
  last_error = f"Env error: {str(e)}"
224
- log_step(step_num, action, 0.0, True, error=last_error)
225
  break
226
 
227
  rewards.append(reward)
228
  log_step(step_num, action, reward, done, error=last_error)
229
 
230
  if done:
231
- success = (reward >= 1.0) # Assume 1.0 is full success
232
  break
233
 
234
  # Build observation summary for the LLM
@@ -250,8 +250,8 @@ def run_task(task: dict):
250
  conversation.append({"role": "assistant", "content": json.dumps(action)})
251
  conversation.append({"role": "user", "content": f"Observation from environment:\n{obs_text}\n\nDecide your next action."})
252
 
253
- # Calculate final score (normalized to [0, 1])
254
- final_score = max(0.0, min(1.0, sum(rewards)))
255
 
256
  log_end(success=success, total_steps=step_num, score=final_score, rewards=rewards)
257
 
 
191
  obs = reset_data.get("observation", {})
192
  info = obs.get("info", "")
193
  except Exception as e:
194
+ log_end(success=False, total_steps=0, score=0.01, rewards=[])
195
  return
196
 
197
  conversation = [
 
209
  action = ask_llm(system_prompt, conversation)
210
  except Exception as e:
211
  last_error = f"LLM error: {str(e)}"
212
+ log_step(step_num, {"error": "LLM failed"}, 0.01, True, error=last_error)
213
  break
214
 
215
  # Execute the action in the environment
 
221
  last_error = obs.get("last_action_error")
222
  except Exception as e:
223
  last_error = f"Env error: {str(e)}"
224
+ log_step(step_num, action, 0.01, True, error=last_error)
225
  break
226
 
227
  rewards.append(reward)
228
  log_step(step_num, action, reward, done, error=last_error)
229
 
230
  if done:
231
+ success = (reward >= 0.8) # Assume 0.8+ is full success (max is 0.85)
232
  break
233
 
234
  # Build observation summary for the LLM
 
250
  conversation.append({"role": "assistant", "content": json.dumps(action)})
251
  conversation.append({"role": "user", "content": f"Observation from environment:\n{obs_text}\n\nDecide your next action."})
252
 
253
+ # Calculate final score (normalized to (0, 1) to satisfy validator)
254
+ final_score = max(0.01, min(0.99, sum(rewards)))
255
 
256
  log_end(success=success, total_steps=step_num, score=final_score, rewards=rewards)
257
 
server/environment.py CHANGED
@@ -12,7 +12,7 @@ class CloudAuditEnv:
12
  self.episode_id = str(uuid.uuid4())
13
  self.step_count = 0
14
  self.is_completed = False
15
- self.score = 0.0
16
 
17
  # Mock Infrastructure
18
  self.resources = {
@@ -40,13 +40,13 @@ class CloudAuditEnv:
40
  """Required by openenv-core 0.1.1: takes task_id, returns JUST the observation."""
41
  self.task_id = task_id
42
  self._initialize_state()
43
- return CloudObservation(info=f"Environment reset. Task: {self.task_id}", reward=0.0, done=False)
44
 
45
  def step(self, action: CloudAction) -> CloudObservation:
46
  """Required by openenv-core 0.1.1: takes action, returns JUST the observation with reward/done fields."""
47
  try:
48
  self.step_count += 1
49
- reward = 0.0
50
  terminated = False
51
  truncated = self.step_count >= 20 # Limit steps
52
 
@@ -86,7 +86,7 @@ class CloudAuditEnv:
86
  rules = self.resources["ec2"][0]["security_groups"][0]["rules"]
87
  has_rdp = any(r["port"] == 3389 and r["cidr"] == "0.0.0.0/0" for r in rules)
88
  if not has_rdp:
89
- reward = 1.0
90
  terminated = True
91
  obs.info = "Success! Port 3389 removed. Task completed."
92
  else:
@@ -112,7 +112,7 @@ class CloudAuditEnv:
112
  answers = [a.strip() for a in action.answer.split(",")]
113
  expected = ["prod-data-001"]
114
  if set(answers) == set(expected):
115
- reward = 1.0
116
  terminated = True
117
  obs.info = "Correct! Task completed."
118
  else:
@@ -121,7 +121,7 @@ class CloudAuditEnv:
121
  elif self.task_id == "hard":
122
  # Expecting rogue IP from auth-logs
123
  if action.answer and action.answer.strip() == "192.168.1.50":
124
- reward = 1.0
125
  terminated = True
126
  obs.info = "Correct! Rogue IP identified. Task completed."
127
  else:
@@ -130,7 +130,7 @@ class CloudAuditEnv:
130
  elif self.task_id == "medium":
131
  obs.info = "For the medium task, use the 'modify' action to update the EC2 security group, not 'submit'."
132
 
133
- self.score += reward
134
  obs.reward = reward
135
  obs.done = terminated or truncated
136
  return obs
@@ -139,7 +139,7 @@ class CloudAuditEnv:
139
  import traceback
140
  print(f"ERROR in environment.step: {str(e)}", file=sys.stderr)
141
  traceback.print_exc(file=sys.stderr)
142
- return CloudObservation(status=f"Internal Server Error: {str(e)}", reward=0.0, done=True)
143
 
144
  def state(self) -> CloudState:
145
  return CloudState(
 
12
  self.episode_id = str(uuid.uuid4())
13
  self.step_count = 0
14
  self.is_completed = False
15
+ self.score = 0.01
16
 
17
  # Mock Infrastructure
18
  self.resources = {
 
40
  """Required by openenv-core 0.1.1: takes task_id, returns JUST the observation."""
41
  self.task_id = task_id
42
  self._initialize_state()
43
+ return CloudObservation(info=f"Environment reset. Task: {self.task_id}", reward=0.01, done=False)
44
 
45
  def step(self, action: CloudAction) -> CloudObservation:
46
  """Required by openenv-core 0.1.1: takes action, returns JUST the observation with reward/done fields."""
47
  try:
48
  self.step_count += 1
49
+ reward = 0.005
50
  terminated = False
51
  truncated = self.step_count >= 20 # Limit steps
52
 
 
86
  rules = self.resources["ec2"][0]["security_groups"][0]["rules"]
87
  has_rdp = any(r["port"] == 3389 and r["cidr"] == "0.0.0.0/0" for r in rules)
88
  if not has_rdp:
89
+ reward = 0.85
90
  terminated = True
91
  obs.info = "Success! Port 3389 removed. Task completed."
92
  else:
 
112
  answers = [a.strip() for a in action.answer.split(",")]
113
  expected = ["prod-data-001"]
114
  if set(answers) == set(expected):
115
+ reward = 0.85
116
  terminated = True
117
  obs.info = "Correct! Task completed."
118
  else:
 
121
  elif self.task_id == "hard":
122
  # Expecting rogue IP from auth-logs
123
  if action.answer and action.answer.strip() == "192.168.1.50":
124
+ reward = 0.85
125
  terminated = True
126
  obs.info = "Correct! Rogue IP identified. Task completed."
127
  else:
 
130
  elif self.task_id == "medium":
131
  obs.info = "For the medium task, use the 'modify' action to update the EC2 security group, not 'submit'."
132
 
133
+ self.score = min(0.99, self.score + reward)
134
  obs.reward = reward
135
  obs.done = terminated or truncated
136
  return obs
 
139
  import traceback
140
  print(f"ERROR in environment.step: {str(e)}", file=sys.stderr)
141
  traceback.print_exc(file=sys.stderr)
142
+ return CloudObservation(status=f"Internal Server Error: {str(e)}", reward=0.01, done=True)
143
 
144
  def state(self) -> CloudState:
145
  return CloudState(
server/tasks.py CHANGED
@@ -5,19 +5,19 @@ TASKS = {
5
  "id": "easy",
6
  "description": "Identify all public S3 buckets in the 'prod' region. Submit their IDs as a comma-separated string.",
7
  "difficulty": "Easy",
8
- "reward_threshold": 1.0,
9
  },
10
  "medium": {
11
  "id": "medium",
12
  "description": "Find the EC2 instance with ID 'i-0abcdef1234567890' in the 'dev' environment. Close port 3389 (RDP) by updating its security group rules to only allow port 22.",
13
  "difficulty": "Medium",
14
- "reward_threshold": 1.0,
15
  },
16
  "hard": {
17
  "id": "hard",
18
  "description": "A rogue IAM role 'iam-role-01' has been performing unauthorized actions. Fetch the 'auth-logs' and identify the IP address that performed 'DeleteStorage'. Submit the IP address.",
19
  "difficulty": "Hard",
20
- "reward_threshold": 1.0,
21
  }
22
  }
23
 
 
5
  "id": "easy",
6
  "description": "Identify all public S3 buckets in the 'prod' region. Submit their IDs as a comma-separated string.",
7
  "difficulty": "Easy",
8
+ "reward_threshold": 0.8,
9
  },
10
  "medium": {
11
  "id": "medium",
12
  "description": "Find the EC2 instance with ID 'i-0abcdef1234567890' in the 'dev' environment. Close port 3389 (RDP) by updating its security group rules to only allow port 22.",
13
  "difficulty": "Medium",
14
+ "reward_threshold": 0.8,
15
  },
16
  "hard": {
17
  "id": "hard",
18
  "description": "A rogue IAM role 'iam-role-01' has been performing unauthorized actions. Fetch the 'auth-logs' and identify the IP address that performed 'DeleteStorage'. Submit the IP address.",
19
  "difficulty": "Hard",
20
+ "reward_threshold": 0.8,
21
  }
22
  }
23