Imaginephoenix commited on
Commit
d49ad1e
·
verified ·
1 Parent(s): 6199458

Delete environment.py

Browse files
Files changed (1) hide show
  1. environment.py +0 -469
environment.py DELETED
@@ -1,469 +0,0 @@
1
- """Core OpenEnv email triage environment implementation."""
2
-
3
- import os
4
- from typing import cast
5
-
6
- from pydantic import ValidationError
7
-
8
- from graders import SCORE_EPSILON, grade_easy, grade_hard, grade_medium_step
9
- from models import (
10
- EmailObservation,
11
- EnvironmentState,
12
- ResetResult,
13
- RewardResult,
14
- StepResult,
15
- TriageAction,
16
- )
17
- from tasks import get_task_definition
18
-
19
-
20
- class EmailTriageEnv:
21
- """Deterministic email triage environment implementing reset, step, and state."""
22
-
23
- def __init__(
24
- self,
25
- task_id: str,
26
- scenario_index: int = 0,
27
- split: str | None = None,
28
- runtime_options: dict[str, object] | None = None,
29
- ) -> None:
30
- """Initialize environment with a selected task.
31
-
32
- Args:
33
- task_id: Task identifier such as task_easy, task_medium, or task_hard.
34
- scenario_index: Deterministic scenario index within the task pool.
35
- split: Scenario split, either public or private_eval.
36
- runtime_options: Optional deterministic runtime controls for task generation.
37
- """
38
- self.task_id = task_id
39
- self._episode_index = max(0, scenario_index)
40
- self.split = split or os.getenv("OPENENV_EVAL_SPLIT", "public")
41
- self.runtime_options = runtime_options or {}
42
- self._task_definition = get_task_definition(
43
- task_id,
44
- self._episode_index,
45
- self.split,
46
- self.runtime_options,
47
- )
48
- self._scenario_id = str(self._task_definition.get("scenario_id", "unknown"))
49
- self._emails = cast(list[dict[str, object]], self._task_definition.get("emails", []))
50
- self._ground_truth = cast(
51
- list[dict[str, object]], self._task_definition.get("ground_truth", [])
52
- )
53
-
54
- self._current_index = 0
55
- self._current_step = 0
56
- self._done = False
57
- self._max_steps = max(10, len(self._emails) + 5)
58
- self._action_history: list[TriageAction] = []
59
- self._reward_history: list[float] = []
60
- self._base_score_history: list[float] = []
61
- self._generated_followups = 0
62
- self._max_generated_followups = 4
63
- self._followup_quality_threshold = 0.7
64
- self._configure_runtime_controls()
65
-
66
- def reset(self) -> ResetResult:
67
- """Reset episode state and return the first observation.
68
-
69
- Returns:
70
- ResetResult containing first observation and metadata.
71
- """
72
- self._task_definition = get_task_definition(
73
- self.task_id,
74
- self._episode_index,
75
- self.split,
76
- self.runtime_options,
77
- )
78
- self._scenario_id = str(self._task_definition.get("scenario_id", "unknown"))
79
- self._emails = cast(list[dict[str, object]], self._task_definition.get("emails", []))
80
- self._ground_truth = cast(
81
- list[dict[str, object]], self._task_definition.get("ground_truth", [])
82
- )
83
-
84
- self._current_index = 0
85
- self._current_step = 0
86
- self._done = False
87
- self._max_steps = max(10, len(self._emails) + 5)
88
- self._action_history = []
89
- self._reward_history = []
90
- self._base_score_history = []
91
- self._generated_followups = 0
92
- self._configure_runtime_controls()
93
- self._episode_index += 1
94
-
95
- first_observation = self._build_observation(self._current_index)
96
- return ResetResult(
97
- observation=first_observation,
98
- info={
99
- "task_id": self.task_id,
100
- "scenario_id": self._scenario_id,
101
- "split": self.split,
102
- "step": self._current_step,
103
- "emails_total": len(self._emails),
104
- "task_description": str(self._task_definition.get("description", "")),
105
- },
106
- )
107
-
108
- def step(self, action: TriageAction) -> StepResult:
109
- """Apply an action and return StepResult.
110
-
111
- Args:
112
- action: Proposed triage action.
113
-
114
- Returns:
115
- StepResult with next observation, reward, done flag, and metadata.
116
- """
117
- if self._done:
118
- return StepResult(
119
- observation=self._terminal_observation(),
120
- reward=SCORE_EPSILON,
121
- done=True,
122
- info={
123
- "task_id": self.task_id,
124
- "scenario_id": self._scenario_id,
125
- "split": self.split,
126
- "step": self._current_step,
127
- "already_done": True,
128
- },
129
- )
130
-
131
- try:
132
- validated_action = TriageAction.model_validate(action)
133
- except ValidationError as validation_error:
134
- self._current_step += 1
135
- self._reward_history.append(SCORE_EPSILON)
136
- self._done = self._current_step >= self._max_steps
137
- return StepResult(
138
- observation=self._build_observation(self._current_index),
139
- reward=SCORE_EPSILON,
140
- done=self._done,
141
- info={
142
- "task_id": self.task_id,
143
- "scenario_id": self._scenario_id,
144
- "split": self.split,
145
- "step": self._current_step,
146
- "emails_total": len(self._emails),
147
- "emails_processed": self._current_index,
148
- "emails_remaining": max(len(self._emails) - self._current_index, 0),
149
- "validation_error": str(validation_error),
150
- },
151
- )
152
-
153
- base_result = self._grade_current_step(validated_action)
154
- base_score = base_result.score
155
- previous_base_score = self._base_score_history[-1] if self._base_score_history else None
156
- progress_signal = self._compute_progress_signal(base_score, previous_base_score)
157
-
158
- truth_for_step = (
159
- self._ground_truth[min(self._current_index, len(self._ground_truth) - 1)]
160
- if self._ground_truth
161
- else {}
162
- )
163
- self._maybe_enqueue_follow_up(validated_action, truth_for_step, base_score)
164
-
165
- self._action_history.append(validated_action)
166
- self._base_score_history.append(base_score)
167
- self._current_step += 1
168
-
169
- penalties = self._compute_penalties(validated_action)
170
- trajectory_bonus = self._compute_trajectory_bonus()
171
- step_cost = self._compute_step_cost()
172
- final_reward = self._clip_reward(
173
- base_score + progress_signal + trajectory_bonus - penalties - step_cost
174
- )
175
-
176
- self._reward_history.append(final_reward)
177
-
178
- if self._current_index < len(self._emails):
179
- self._current_index += 1
180
-
181
- all_emails_processed = self._current_index >= len(self._emails)
182
- self._done = all_emails_processed or self._current_step >= self._max_steps
183
-
184
- next_observation = (
185
- self._terminal_observation()
186
- if self._done
187
- else self._build_observation(self._current_index)
188
- )
189
-
190
- info = {
191
- "task_id": self.task_id,
192
- "scenario_id": self._scenario_id,
193
- "split": self.split,
194
- "step": self._current_step,
195
- "emails_total": len(self._emails),
196
- "emails_processed": min(self._current_index, len(self._emails)),
197
- "emails_remaining": max(len(self._emails) - self._current_index, 0),
198
- "base_score": round(base_score, 4),
199
- "progress_signal": round(progress_signal, 4),
200
- "step_cost": round(step_cost, 4),
201
- "penalties": round(penalties, 4),
202
- "trajectory_bonus": round(trajectory_bonus, 4),
203
- "grading_feedback": base_result.feedback,
204
- }
205
- for breakdown_key, breakdown_value in base_result.breakdown.items():
206
- if isinstance(breakdown_value, (int, float)):
207
- info[f"grade_{breakdown_key}"] = round(float(breakdown_value), 4)
208
-
209
- return StepResult(
210
- observation=next_observation,
211
- reward=final_reward,
212
- done=self._done,
213
- info=info,
214
- )
215
-
216
- def _maybe_enqueue_follow_up(
217
- self,
218
- action: TriageAction,
219
- truth: dict[str, object],
220
- base_score: float,
221
- ) -> None:
222
- """Insert deterministic escalation follow-up emails for production mode."""
223
- if self.task_id != "task_production":
224
- return
225
- if self._generated_followups >= self._max_generated_followups:
226
- return
227
- if not self._emails:
228
- return
229
-
230
- expected_label = str(truth.get("label", ""))
231
- expected_route = str(truth.get("route_to", "general"))
232
- is_missed_critical = (
233
- expected_label == "urgent"
234
- and (action.label != "urgent" or expected_route not in action.route_to.lower())
235
- )
236
- if not is_missed_critical and base_score >= self._followup_quality_threshold:
237
- return
238
-
239
- source_email = self._emails[min(self._current_index, len(self._emails) - 1)]
240
- source_subject = str(source_email.get("subject", "Inbox incident"))
241
- source_timestamp = str(source_email.get("timestamp", "2026-04-03T00:00:00Z"))
242
-
243
- followup_email = {
244
- "email_id": f"followup-{self._scenario_id}-{self._generated_followups + 1}",
245
- "subject": f"Escalation follow-up: {source_subject}",
246
- "body": (
247
- "Automated escalation triggered because prior triage appears incomplete. "
248
- "Please route to the responsible team and provide a clear summary now."
249
- ),
250
- "sender": "incident-control@acme-enterprise.com",
251
- "timestamp": source_timestamp,
252
- "thread_history": [f"Previous message subject: {source_subject}"],
253
- }
254
- followup_truth = {
255
- "label": "urgent",
256
- "route_to": expected_route,
257
- "priority_weight": min(max(float(truth.get("priority_weight", 1.5)) + 0.2, 1.5), 2.0),
258
- "summary_keywords": ["escalation", "follow-up", expected_route],
259
- }
260
-
261
- insert_at = min(self._current_index + 1, len(self._emails))
262
- self._emails.insert(insert_at, followup_email)
263
- self._ground_truth.insert(insert_at, followup_truth)
264
- self._generated_followups += 1
265
-
266
- def _configure_runtime_controls(self) -> None:
267
- """Apply deterministic runtime control options for production simulator."""
268
- if self.task_id != "task_production":
269
- self._max_generated_followups = 4
270
- self._followup_quality_threshold = 0.7
271
- return
272
-
273
- escalation_mode = str(self.runtime_options.get("escalation_mode", "normal")).lower()
274
- escalation_map = {
275
- "low": (2, 0.55),
276
- "normal": (4, 0.7),
277
- "high": (8, 0.85),
278
- }
279
- max_followups, threshold = escalation_map.get(escalation_mode, escalation_map["normal"])
280
- self._max_generated_followups = max_followups
281
- self._followup_quality_threshold = threshold
282
-
283
- def state(self) -> EnvironmentState:
284
- """Return read-only snapshot of full internal state.
285
-
286
- Returns:
287
- EnvironmentState with progress and history.
288
- """
289
- return EnvironmentState(
290
- task_id=self.task_id,
291
- current_step=self._current_step,
292
- total_steps=self._max_steps,
293
- done=self._done,
294
- action_history=list(self._action_history),
295
- reward_history=list(self._reward_history),
296
- )
297
-
298
- def _build_observation(self, email_index: int) -> EmailObservation:
299
- """Build observation for the email at a given index.
300
-
301
- Args:
302
- email_index: Zero-based email index.
303
-
304
- Returns:
305
- EmailObservation for the selected email or terminal placeholder.
306
- """
307
- if not self._emails:
308
- return self._terminal_observation()
309
-
310
- safe_index = min(max(email_index, 0), len(self._emails) - 1)
311
- email_payload = self._emails[safe_index]
312
-
313
- return EmailObservation(
314
- email_id=str(email_payload.get("email_id", "")),
315
- subject=str(email_payload.get("subject", "")),
316
- body=str(email_payload.get("body", "")),
317
- sender=str(email_payload.get("sender", "")),
318
- timestamp=str(email_payload.get("timestamp", "")),
319
- thread_history=[str(item) for item in email_payload.get("thread_history", [])],
320
- task_id=self.task_id,
321
- step_number=self._current_step,
322
- total_emails=len(self._emails),
323
- )
324
-
325
- def _terminal_observation(self) -> EmailObservation:
326
- """Build terminal observation returned when episode is complete.
327
-
328
- Returns:
329
- Terminal EmailObservation payload.
330
- """
331
- return EmailObservation(
332
- email_id="terminal",
333
- subject="Episode complete",
334
- body="No further emails remain for this task.",
335
- sender="system",
336
- timestamp="",
337
- thread_history=[],
338
- task_id=self.task_id,
339
- step_number=self._current_step,
340
- total_emails=len(self._emails),
341
- )
342
-
343
- def _grade_current_step(self, action: TriageAction) -> RewardResult:
344
- """Select deterministic grader based on task and current progress.
345
-
346
- Args:
347
- action: Validated action for the current step.
348
-
349
- Returns:
350
- RewardResult from task-specific grader.
351
- """
352
- if not self._ground_truth:
353
- return RewardResult(
354
- score=SCORE_EPSILON,
355
- breakdown={"missing_ground_truth": 1.0 - SCORE_EPSILON},
356
- feedback="Missing ground truth for task.",
357
- )
358
-
359
- if self.task_id == "task_easy":
360
- truth = self._ground_truth[min(self._current_index, len(self._ground_truth) - 1)]
361
- return grade_easy(action, truth)
362
-
363
- if self.task_id == "task_medium":
364
- truth = self._ground_truth[min(self._current_index, len(self._ground_truth) - 1)]
365
- return grade_medium_step(action, truth)
366
-
367
- truth = self._ground_truth[min(self._current_index, len(self._ground_truth) - 1)]
368
- return grade_hard(action, truth)
369
-
370
- def _compute_penalties(self, action: TriageAction) -> float:
371
- """Compute deterministic penalties according to reward policy.
372
-
373
- Args:
374
- action: Validated action for the step.
375
-
376
- Returns:
377
- Total penalty value for current step.
378
- """
379
- penalty_total = 0.0
380
-
381
- summary_too_short = len(action.summary.strip()) < 10
382
- if action.label == "archive" and summary_too_short:
383
- penalty_total += 0.5
384
-
385
- if self._is_repeated_action_pattern(action):
386
- penalty_total += 0.3
387
-
388
- return penalty_total
389
-
390
- def _compute_progress_signal(
391
- self,
392
- base_score: float,
393
- previous_base_score: float | None,
394
- ) -> float:
395
- """Compute dense partial-progress reward independent of final completion.
396
-
397
- Args:
398
- base_score: Current-step base grade in [0.0, 1.0].
399
- previous_base_score: Previous step base grade when available.
400
-
401
- Returns:
402
- Small positive/negative signal reflecting progress and quality trend.
403
- """
404
- total_emails = max(len(self._emails), 1)
405
- progress_ratio = min(1.0, (self._current_index + 1) / total_emails)
406
-
407
- completion_signal = 0.05 * progress_ratio
408
- quality_signal = 0.05 * self._clip_reward(base_score)
409
-
410
- trend_signal = 0.0
411
- if previous_base_score is not None:
412
- delta = base_score - previous_base_score
413
- trend_signal = max(-0.02, min(0.03, delta * 0.1))
414
-
415
- return completion_signal + quality_signal + trend_signal
416
-
417
- def _compute_step_cost(self) -> float:
418
- """Return a gentle efficiency cost that grows with episode length."""
419
- normalized_step = self._current_step / max(self._max_steps, 1)
420
- return 0.005 + (0.01 * normalized_step)
421
-
422
- def _compute_trajectory_bonus(self) -> float:
423
- """Return trajectory bonus when episode completion quality is high.
424
-
425
- Returns:
426
- 0.2 when mean base score is above threshold at completion, else 0.0.
427
- """
428
- if not self._base_score_history:
429
- return 0.0
430
-
431
- all_emails_done_after_step = self._current_index + 1 >= len(self._emails)
432
- if not all_emails_done_after_step:
433
- return 0.0
434
-
435
- mean_base = sum(self._base_score_history) / len(self._base_score_history)
436
- return 0.2 if mean_base > 0.8 else 0.0
437
-
438
- def _is_repeated_action_pattern(self, action: TriageAction) -> bool:
439
- """Detect whether same action appears three times consecutively.
440
-
441
- Args:
442
- action: Current action.
443
-
444
- Returns:
445
- True when repeated label and route occur three times in a row.
446
- """
447
- if len(self._action_history) < 2:
448
- return False
449
-
450
- previous_action = self._action_history[-1]
451
- older_action = self._action_history[-2]
452
-
453
- return (
454
- previous_action.label == older_action.label == action.label
455
- and previous_action.route_to.strip().lower()
456
- == older_action.route_to.strip().lower()
457
- == action.route_to.strip().lower()
458
- )
459
-
460
- def _clip_reward(self, reward_value: float) -> float:
461
- """Clip reward to the strict range [0.0, 1.0].
462
-
463
- Args:
464
- reward_value: Raw reward value.
465
-
466
- Returns:
467
- Clipped reward.
468
- """
469
- return max(SCORE_EPSILON, min(1.0 - SCORE_EPSILON, reward_value))