Imaginephoenix committed on
Commit
48058c0
·
verified ·
1 Parent(s): b409c37

Delete environment.py

Browse files
Files changed (1) hide show
  1. environment.py +0 -419
environment.py DELETED
@@ -1,419 +0,0 @@
1
- """Core OpenEnv email triage environment implementation."""
2
-
3
- import os
4
- from typing import cast
5
-
6
- from pydantic import ValidationError
7
-
8
- from graders import grade_easy, grade_hard, grade_medium_step
9
- from models import (
10
- EmailObservation,
11
- EnvironmentState,
12
- ResetResult,
13
- RewardResult,
14
- StepResult,
15
- TriageAction,
16
- )
17
- from tasks import get_task_definition
18
-
19
-
20
class EmailTriageEnv:
    """Deterministic email triage environment implementing reset, step, and state.

    An episode walks through the emails of one scenario in order. Each call to
    ``step`` grades the submitted action against the ground truth for the
    current email, applies penalties and an optional completion bonus, and
    advances to the next email. For ``task_production`` the environment may
    also insert escalation follow-up emails directly after the current one.

    The step budget (``_max_steps``) is fixed when a scenario is loaded;
    follow-up emails inserted later do not raise it.
    """

    # Escalation presets for task_production:
    # mode -> (max generated follow-ups, follow-up quality threshold).
    _ESCALATION_PRESETS: dict[str, tuple[int, float]] = {
        "low": (2, 0.55),
        "normal": (4, 0.7),
        "high": (8, 0.85),
    }
    _DEFAULT_ESCALATION_MODE = "normal"
    _PRODUCTION_TASK_ID = "task_production"

    def __init__(
        self,
        task_id: str,
        scenario_index: int = 0,
        split: str | None = None,
        runtime_options: dict[str, object] | None = None,
    ) -> None:
        """Initialize environment with a selected task.

        Args:
            task_id: Task identifier such as task_easy, task_medium, or task_hard.
            scenario_index: Deterministic scenario index within the task pool.
                Negative values are clamped to 0.
            split: Scenario split, either public or private_eval. Falls back to
                the OPENENV_EVAL_SPLIT environment variable, then to "public".
            runtime_options: Optional deterministic runtime controls for task generation.
        """
        self.task_id = task_id
        self._episode_index = max(0, scenario_index)
        self.split = split or os.getenv("OPENENV_EVAL_SPLIT", "public")
        self.runtime_options = runtime_options or {}
        # Defaults; _load_scenario() overwrites them via _configure_runtime_controls().
        self._max_generated_followups = 4
        self._followup_quality_threshold = 0.7
        self._load_scenario()

    def _load_scenario(self) -> None:
        """Fetch the scenario for the current episode index and clear episode state."""
        self._task_definition = get_task_definition(
            self.task_id,
            self._episode_index,
            self.split,
            self.runtime_options,
        )
        self._scenario_id = str(self._task_definition.get("scenario_id", "unknown"))
        # Shallow-copy both lists: follow-up generation inserts into them, and
        # the task definition's own lists must not be modified by an episode.
        self._emails = list(
            cast(list[dict[str, object]], self._task_definition.get("emails") or [])
        )
        self._ground_truth = list(
            cast(list[dict[str, object]], self._task_definition.get("ground_truth") or [])
        )

        self._current_index = 0
        self._current_step = 0
        self._done = False
        self._max_steps = max(10, len(self._emails) + 5)
        self._action_history: list[TriageAction] = []
        self._reward_history: list[float] = []
        self._base_score_history: list[float] = []
        self._generated_followups = 0
        self._configure_runtime_controls()

    def _base_info(self) -> dict[str, object]:
        """Return the metadata keys shared by every reset/step result."""
        return {
            "task_id": self.task_id,
            "scenario_id": self._scenario_id,
            "split": self.split,
            "step": self._current_step,
        }

    def _current_truth(self) -> dict[str, object]:
        """Return ground truth for the current email, or {} when none is defined.

        The index is clamped to the last ground-truth entry.
        """
        if not self._ground_truth:
            return {}
        return self._ground_truth[min(self._current_index, len(self._ground_truth) - 1)]

    def reset(self) -> ResetResult:
        """Reset episode state and return the first observation.

        The scenario for the current episode index is loaded, then the index is
        advanced so that the next reset loads the following scenario.

        Returns:
            ResetResult containing first observation and metadata.
        """
        self._load_scenario()
        self._episode_index += 1

        return ResetResult(
            observation=self._build_observation(self._current_index),
            info=self._base_info(),
        )

    def step(self, action: TriageAction) -> StepResult:
        """Apply an action and return StepResult.

        Args:
            action: Proposed triage action.

        Returns:
            StepResult with next observation, reward, done flag, and metadata.
            An action that fails validation consumes a step, earns 0.0 reward,
            and leaves the current email unchanged.
        """
        if self._done:
            return StepResult(
                observation=self._terminal_observation(),
                reward=0.0,
                done=True,
                info={**self._base_info(), "already_done": True},
            )

        try:
            validated_action = TriageAction.model_validate(action)
        except ValidationError as validation_error:
            self._current_step += 1
            self._reward_history.append(0.0)
            self._done = self._current_step >= self._max_steps
            # Match the normal path: a finished episode yields the terminal observation.
            observation = (
                self._terminal_observation()
                if self._done
                else self._build_observation(self._current_index)
            )
            return StepResult(
                observation=observation,
                reward=0.0,
                done=self._done,
                info={**self._base_info(), "validation_error": str(validation_error)},
            )

        base_score = self._grade_current_step(validated_action).score
        self._maybe_enqueue_follow_up(validated_action, self._current_truth(), base_score)

        self._base_score_history.append(base_score)
        self._current_step += 1

        # Penalties are computed BEFORE the action joins the history; otherwise
        # the repeat check would compare the action with itself and fire after
        # two identical actions instead of three.
        penalties = self._compute_penalties(validated_action)
        self._action_history.append(validated_action)

        trajectory_bonus = self._compute_trajectory_bonus()
        final_reward = self._clip_reward(
            base_score - (self._current_step * 0.01) + trajectory_bonus - penalties
        )
        self._reward_history.append(final_reward)

        if self._current_index < len(self._emails):
            self._current_index += 1

        all_emails_processed = self._current_index >= len(self._emails)
        self._done = all_emails_processed or self._current_step >= self._max_steps

        next_observation = (
            self._terminal_observation()
            if self._done
            else self._build_observation(self._current_index)
        )

        return StepResult(
            observation=next_observation,
            reward=final_reward,
            done=self._done,
            info={
                **self._base_info(),
                "base_score": round(base_score, 4),
                "penalties": round(penalties, 4),
                "trajectory_bonus": round(trajectory_bonus, 4),
            },
        )

    def _maybe_enqueue_follow_up(
        self,
        action: TriageAction,
        truth: dict[str, object],
        base_score: float,
    ) -> None:
        """Insert deterministic escalation follow-up emails for production mode.

        A follow-up is inserted right after the current email when the action
        missed an urgent email (wrong label or wrong route) or when the base
        score is below the follow-up quality threshold. Nothing happens outside
        ``task_production``, once the follow-up cap is reached, or when the
        scenario has no emails.

        Args:
            action: Validated action for the current email.
            truth: Ground truth for the current email.
            base_score: Grader score for the action.
        """
        if self.task_id != self._PRODUCTION_TASK_ID:
            return
        if self._generated_followups >= self._max_generated_followups:
            return
        if not self._emails:
            return

        expected_label = str(truth.get("label", ""))
        expected_route = str(truth.get("route_to", "general"))
        # Compare routes case-insensitively on both sides.
        route_matches = expected_route.strip().lower() in action.route_to.lower()
        is_missed_critical = expected_label == "urgent" and (
            action.label != "urgent" or not route_matches
        )
        if not is_missed_critical and base_score >= self._followup_quality_threshold:
            return

        source_email = self._emails[min(self._current_index, len(self._emails) - 1)]
        source_subject = str(source_email.get("subject", "Inbox incident"))
        source_timestamp = str(source_email.get("timestamp", "2026-04-03T00:00:00Z"))

        try:
            base_weight = float(truth.get("priority_weight", 1.5))  # type: ignore[arg-type]
        except (TypeError, ValueError):
            # Missing-but-present (None) or non-numeric weight: use the default.
            base_weight = 1.5

        followup_email = {
            "email_id": f"followup-{self._scenario_id}-{self._generated_followups + 1}",
            "subject": f"Escalation follow-up: {source_subject}",
            "body": (
                "Automated escalation triggered because prior triage appears incomplete. "
                "Please route to the responsible team and provide a clear summary now."
            ),
            "sender": "incident-control@acme-enterprise.com",
            "timestamp": source_timestamp,
            "thread_history": [f"Previous message subject: {source_subject}"],
        }
        followup_truth = {
            "label": "urgent",
            "route_to": expected_route,
            # Source weight bumped by 0.2, kept within [1.5, 2.0].
            "priority_weight": min(max(base_weight + 0.2, 1.5), 2.0),
            "summary_keywords": ["escalation", "follow-up", expected_route],
        }

        insert_at = min(self._current_index + 1, len(self._emails))
        self._emails.insert(insert_at, followup_email)
        self._ground_truth.insert(insert_at, followup_truth)
        self._generated_followups += 1

    def _configure_runtime_controls(self) -> None:
        """Apply deterministic runtime control options for production simulator.

        Reads ``escalation_mode`` from ``runtime_options`` (case-insensitive).
        Unknown modes and every task other than ``task_production`` use the
        "normal" preset.
        """
        mode = self._DEFAULT_ESCALATION_MODE
        if self.task_id == self._PRODUCTION_TASK_ID:
            requested = self.runtime_options.get("escalation_mode", mode)
            mode = str(requested).strip().lower()

        max_followups, threshold = self._ESCALATION_PRESETS.get(
            mode, self._ESCALATION_PRESETS[self._DEFAULT_ESCALATION_MODE]
        )
        self._max_generated_followups = max_followups
        self._followup_quality_threshold = threshold

    def state(self) -> EnvironmentState:
        """Return a snapshot of episode progress and history.

        The history lists are copies, so mutating the returned lists does not
        affect the environment.

        Returns:
            EnvironmentState with progress and history.
        """
        return EnvironmentState(
            task_id=self.task_id,
            current_step=self._current_step,
            total_steps=self._max_steps,
            done=self._done,
            action_history=list(self._action_history),
            reward_history=list(self._reward_history),
        )

    def _build_observation(self, email_index: int) -> EmailObservation:
        """Build observation for the email at a given index.

        Args:
            email_index: Zero-based email index; clamped to the valid range.

        Returns:
            EmailObservation for the selected email or terminal placeholder
            when the scenario has no emails.
        """
        if not self._emails:
            return self._terminal_observation()

        safe_index = min(max(email_index, 0), len(self._emails) - 1)
        email_payload = self._emails[safe_index]
        # Treat a missing or None thread history as empty.
        thread_items = cast(list[object], email_payload.get("thread_history") or [])

        return EmailObservation(
            email_id=str(email_payload.get("email_id", "")),
            subject=str(email_payload.get("subject", "")),
            body=str(email_payload.get("body", "")),
            sender=str(email_payload.get("sender", "")),
            timestamp=str(email_payload.get("timestamp", "")),
            thread_history=[str(item) for item in thread_items],
            task_id=self.task_id,
            step_number=self._current_step,
            total_emails=len(self._emails),
        )

    def _terminal_observation(self) -> EmailObservation:
        """Build terminal observation returned when episode is complete.

        Returns:
            Terminal EmailObservation payload.
        """
        return EmailObservation(
            email_id="terminal",
            subject="Episode complete",
            body="No further emails remain for this task.",
            sender="system",
            timestamp="",
            thread_history=[],
            task_id=self.task_id,
            step_number=self._current_step,
            total_emails=len(self._emails),
        )

    def _grade_current_step(self, action: TriageAction) -> RewardResult:
        """Select deterministic grader based on task and current progress.

        ``task_easy`` and ``task_medium`` use their dedicated graders; every
        other task id is graded with the hard grader.

        Args:
            action: Validated action for the current step.

        Returns:
            RewardResult from task-specific grader, or a zero-score result when
            the scenario has no ground truth.
        """
        if not self._ground_truth:
            return RewardResult(
                score=0.0,
                breakdown={"missing_ground_truth": 1.0},
                feedback="Missing ground truth for task.",
            )

        truth = self._current_truth()
        if self.task_id == "task_easy":
            return grade_easy(action, truth)
        if self.task_id == "task_medium":
            return grade_medium_step(action, truth)
        return grade_hard(action, truth)

    def _compute_penalties(self, action: TriageAction) -> float:
        """Compute deterministic penalties according to reward policy.

        Args:
            action: Validated action for the step. It must not yet be part of
                the action history.

        Returns:
            Total penalty value for current step: 0.5 for archiving with a
            summary shorter than 10 characters, plus 0.3 for a repeated action
            pattern.
        """
        penalty_total = 0.0

        summary_too_short = len(action.summary.strip()) < 10
        if action.label == "archive" and summary_too_short:
            penalty_total += 0.5

        if self._is_repeated_action_pattern(action):
            penalty_total += 0.3

        return penalty_total

    def _compute_trajectory_bonus(self) -> float:
        """Return trajectory bonus when episode completion quality is high.

        Must be called before the current index is advanced for the step.

        Returns:
            0.2 when mean base score is above 0.8 at completion, else 0.0.
        """
        if not self._base_score_history:
            return 0.0

        all_emails_done_after_step = self._current_index + 1 >= len(self._emails)
        if not all_emails_done_after_step:
            return 0.0

        mean_base = sum(self._base_score_history) / len(self._base_score_history)
        return 0.2 if mean_base > 0.8 else 0.0

    def _is_repeated_action_pattern(self, action: TriageAction) -> bool:
        """Detect whether same action appears three times consecutively.

        Args:
            action: Current action, not yet appended to the action history.

        Returns:
            True when the two most recent recorded actions and the current one
            share the same label and the same route (route compared ignoring
            case and surrounding whitespace).
        """
        if len(self._action_history) < 2:
            return False

        recent = (self._action_history[-2], self._action_history[-1], action)
        labels = {item.label for item in recent}
        routes = {item.route_to.strip().lower() for item in recent}
        return len(labels) == 1 and len(routes) == 1

    def _clip_reward(self, reward_value: float) -> float:
        """Clip reward to the inclusive range [-1.0, 1.0].

        Args:
            reward_value: Raw reward value.

        Returns:
            Clipped reward.
        """
        return max(-1.0, min(1.0, reward_value))