Spaces:
Sleeping
Sleeping
| """Core OpenEnv email triage environment implementation.""" | |
| import os | |
| from typing import cast | |
| from pydantic import ValidationError | |
| from graders import SCORE_EPSILON, grade_easy, grade_hard, grade_medium_step | |
| from models import ( | |
| EmailObservation, | |
| EnvironmentState, | |
| ResetResult, | |
| RewardResult, | |
| StepResult, | |
| TriageAction, | |
| ) | |
| from tasks import get_task_definition | |
class EmailTriageEnv:
    """Deterministic email triage environment implementing reset, step, and state.

    An episode walks through the emails of one task scenario in order. Each
    ``step`` grades one triage action against the matching ground-truth entry
    and returns a shaped reward clipped to ``[SCORE_EPSILON, 1 - SCORE_EPSILON]``.
    For ``task_production`` the environment may insert escalation follow-up
    emails into the queue after poorly graded steps.
    """

    def __init__(
        self,
        task_id: str,
        scenario_index: int = 0,
        split: str | None = None,
        runtime_options: dict[str, object] | None = None,
    ) -> None:
        """Initialize environment with a selected task.

        Args:
            task_id: Task identifier such as task_easy, task_medium, or task_hard.
            scenario_index: Deterministic scenario index within the task pool;
                negative values are clamped to 0.
            split: Scenario split, either public or private_eval. When omitted,
                the ``OPENENV_EVAL_SPLIT`` environment variable is used, falling
                back to ``"public"``.
            runtime_options: Optional deterministic runtime controls for task generation.
        """
        self.task_id = task_id
        self._episode_index = max(0, scenario_index)
        self.split = split or os.getenv("OPENENV_EVAL_SPLIT", "public")
        self.runtime_options = runtime_options or {}
        self._load_episode()

    def reset(self) -> ResetResult:
        """Reset episode state and return the first observation.

        The scenario for the current episode index is (re)loaded, and the index
        is then advanced so the following reset moves to the next scenario.

        Returns:
            ResetResult containing first observation and metadata.
        """
        self._load_episode()
        self._episode_index += 1
        return ResetResult(
            observation=self._build_observation(self._current_index),
            info={
                **self._base_info(),
                "emails_total": len(self._emails),
                "task_description": str(self._task_definition.get("description", "")),
            },
        )

    def step(self, action: TriageAction) -> StepResult:
        """Apply an action and return StepResult.

        Args:
            action: Proposed triage action (a TriageAction or any payload that
                ``TriageAction.model_validate`` accepts).

        Returns:
            StepResult with next observation, reward, done flag, and metadata.
            Actions that fail validation consume a step, earn ``SCORE_EPSILON``
            and do not advance the email queue. Stepping a finished episode
            returns the terminal observation with ``already_done`` set.
        """
        if self._done:
            return StepResult(
                observation=self._terminal_observation(),
                reward=SCORE_EPSILON,
                done=True,
                info={**self._base_info(), "already_done": True},
            )

        try:
            validated_action = TriageAction.model_validate(action)
        except ValidationError as validation_error:
            return self._reject_invalid_action(validation_error)

        base_result = self._grade_current_step(validated_action)
        base_score = base_result.score
        previous_base_score = self._base_score_history[-1] if self._base_score_history else None
        progress_signal = self._compute_progress_signal(base_score, previous_base_score)
        # Penalties must be computed BEFORE the action is recorded: the repeat check
        # compares this action with the two previous ones, so appending first would
        # make it compare the action with itself and fire after only two repeats.
        penalties = self._compute_penalties(validated_action)

        self._maybe_enqueue_follow_up(validated_action, self._current_truth(), base_score)
        self._action_history.append(validated_action)
        self._base_score_history.append(base_score)
        self._current_step += 1

        # Both depend on the updated history / step counter above.
        trajectory_bonus = self._compute_trajectory_bonus()
        step_cost = self._compute_step_cost()
        final_reward = self._clip_reward(
            base_score + progress_signal + trajectory_bonus - penalties - step_cost
        )
        self._reward_history.append(final_reward)

        if self._current_index < len(self._emails):
            self._current_index += 1
        all_emails_processed = self._current_index >= len(self._emails)
        self._done = all_emails_processed or self._current_step >= self._max_steps
        next_observation = (
            self._terminal_observation()
            if self._done
            else self._build_observation(self._current_index)
        )

        info: dict[str, object] = {
            **self._base_info(),
            "emails_total": len(self._emails),
            "emails_processed": min(self._current_index, len(self._emails)),
            "emails_remaining": max(len(self._emails) - self._current_index, 0),
            "base_score": float(base_score),
            "progress_signal": float(progress_signal),
            "step_cost": float(step_cost),
            "penalties": float(penalties),
            "trajectory_bonus": float(trajectory_bonus),
            "grading_feedback": base_result.feedback,
        }
        # Surface numeric grader breakdown entries as flat "grade_*" keys.
        for breakdown_key, breakdown_value in base_result.breakdown.items():
            if isinstance(breakdown_value, (int, float)):
                info[f"grade_{breakdown_key}"] = float(breakdown_value)

        return StepResult(
            observation=next_observation,
            reward=final_reward,
            done=self._done,
            info=info,
        )

    def _load_episode(self) -> None:
        """Fetch the task definition for the current episode index and clear episode state.

        The email and ground-truth lists are copied because production follow-up
        emails are inserted into them during an episode; without the copy those
        insertions would also modify the lists held by the task definition.
        """
        self._task_definition = get_task_definition(
            self.task_id,
            self._episode_index,
            self.split,
            self.runtime_options,
        )
        self._scenario_id = str(self._task_definition.get("scenario_id", "unknown"))
        self._emails = list(
            cast(list[dict[str, object]], self._task_definition.get("emails", []))
        )
        self._ground_truth = list(
            cast(list[dict[str, object]], self._task_definition.get("ground_truth", []))
        )
        self._current_index = 0
        self._current_step = 0
        self._done = False
        # At least 10 steps; the +5 leaves room for steps that do not advance the
        # queue (e.g. actions rejected by validation).
        self._max_steps = max(10, len(self._emails) + 5)
        self._action_history: list[TriageAction] = []
        self._reward_history: list[float] = []
        self._base_score_history: list[float] = []
        self._generated_followups = 0
        # Sets _max_generated_followups and _followup_quality_threshold.
        self._configure_runtime_controls()

    def _base_info(self) -> dict[str, object]:
        """Return the identification fields shared by every info payload."""
        return {
            "task_id": self.task_id,
            "scenario_id": self._scenario_id,
            "split": self.split,
            "step": self._current_step,
        }

    def _reject_invalid_action(self, validation_error: ValidationError) -> StepResult:
        """Consume a step for an action that failed validation.

        The email queue does not advance. If this exhausts the step budget the
        episode ends and, as on the normal path, a terminal observation is returned.

        Args:
            validation_error: Error raised while validating the submitted action.

        Returns:
            StepResult with reward ``SCORE_EPSILON`` and the error text in info.
        """
        self._current_step += 1
        self._reward_history.append(SCORE_EPSILON)
        self._done = self._current_step >= self._max_steps
        observation = (
            self._terminal_observation()
            if self._done
            else self._build_observation(self._current_index)
        )
        return StepResult(
            observation=observation,
            reward=SCORE_EPSILON,
            done=self._done,
            info={
                **self._base_info(),
                "emails_total": len(self._emails),
                "emails_processed": self._current_index,
                "emails_remaining": max(len(self._emails) - self._current_index, 0),
                "validation_error": str(validation_error),
            },
        )

    def _current_truth(self) -> dict[str, object]:
        """Return the ground-truth entry for the current email, or ``{}`` if none exist.

        The index is clamped to the last entry so a ground-truth list shorter than
        the email queue never raises IndexError.
        """
        if not self._ground_truth:
            return {}
        return self._ground_truth[min(self._current_index, len(self._ground_truth) - 1)]

    def _maybe_enqueue_follow_up(
        self,
        action: TriageAction,
        truth: dict[str, object],
        base_score: float,
    ) -> None:
        """Insert a deterministic escalation follow-up email (task_production only).

        A follow-up is queued right after the current email when the step either
        missed a critical email (expected label ``urgent`` but the action used a
        different label or a route not containing the expected team) or scored
        below the follow-up quality threshold. At most ``_max_generated_followups``
        are generated per episode.

        Args:
            action: Validated action for the current step.
            truth: Ground-truth entry for the current email (may be empty).
            base_score: Grader score for the current step.
        """
        if self.task_id != "task_production":
            return
        if self._generated_followups >= self._max_generated_followups:
            return
        if not self._emails:
            return
        expected_label = str(truth.get("label", ""))
        expected_route = str(truth.get("route_to", "general"))
        # Compare routes case-insensitively on both sides.
        is_missed_critical = expected_label == "urgent" and (
            action.label != "urgent"
            or expected_route.lower() not in action.route_to.lower()
        )
        if not is_missed_critical and base_score >= self._followup_quality_threshold:
            return
        source_email = self._emails[min(self._current_index, len(self._emails) - 1)]
        source_subject = str(source_email.get("subject", "Inbox incident"))
        source_timestamp = str(source_email.get("timestamp", "2026-04-03T00:00:00Z"))
        followup_email = {
            "email_id": f"followup-{self._scenario_id}-{self._generated_followups + 1}",
            "subject": f"Escalation follow-up: {source_subject}",
            "body": (
                "Automated escalation triggered because prior triage appears incomplete. "
                "Please route to the responsible team and provide a clear summary now."
            ),
            "sender": "incident-control@acme-enterprise.com",
            "timestamp": source_timestamp,
            "thread_history": [f"Previous message subject: {source_subject}"],
        }
        followup_truth = {
            "label": "urgent",
            "route_to": expected_route,
            # Bumped by 0.2 and clamped to [1.5, 2.0].
            "priority_weight": min(max(float(truth.get("priority_weight", 1.5)) + 0.2, 1.5), 2.0),
            "summary_keywords": ["escalation", "follow-up", expected_route],
        }
        # Same position in both lists so emails and ground truth stay aligned.
        insert_at = min(self._current_index + 1, len(self._emails))
        self._emails.insert(insert_at, followup_email)
        self._ground_truth.insert(insert_at, followup_truth)
        self._generated_followups += 1

    def _configure_runtime_controls(self) -> None:
        """Apply deterministic runtime control options for the production simulator.

        For ``task_production`` the ``escalation_mode`` runtime option (``low``,
        ``normal`` or ``high``; unknown values fall back to ``normal``) selects the
        follow-up cap and quality threshold. All other tasks use the normal values.
        """
        if self.task_id != "task_production":
            self._max_generated_followups = 4
            self._followup_quality_threshold = 0.7
            return
        escalation_mode = str(self.runtime_options.get("escalation_mode", "normal")).lower()
        # mode -> (max follow-ups per episode, quality threshold that triggers one)
        escalation_map = {
            "low": (2, 0.55),
            "normal": (4, 0.7),
            "high": (8, 0.85),
        }
        max_followups, threshold = escalation_map.get(escalation_mode, escalation_map["normal"])
        self._max_generated_followups = max_followups
        self._followup_quality_threshold = threshold

    def state(self) -> EnvironmentState:
        """Return a snapshot of the episode's progress and history.

        The history lists are copied so callers cannot mutate internal state.

        Returns:
            EnvironmentState with progress and history.
        """
        return EnvironmentState(
            task_id=self.task_id,
            current_step=self._current_step,
            total_steps=self._max_steps,
            done=self._done,
            action_history=list(self._action_history),
            reward_history=list(self._reward_history),
        )

    def _build_observation(self, email_index: int) -> EmailObservation:
        """Build observation for the email at a given index.

        Args:
            email_index: Zero-based email index; clamped into the valid range.

        Returns:
            EmailObservation for the selected email, or the terminal placeholder
            when there are no emails.
        """
        if not self._emails:
            return self._terminal_observation()
        safe_index = min(max(email_index, 0), len(self._emails) - 1)
        email_payload = self._emails[safe_index]
        # A missing or null thread_history is treated as an empty thread.
        raw_history = email_payload.get("thread_history") or []
        return EmailObservation(
            email_id=str(email_payload.get("email_id", "")),
            subject=str(email_payload.get("subject", "")),
            body=str(email_payload.get("body", "")),
            sender=str(email_payload.get("sender", "")),
            timestamp=str(email_payload.get("timestamp", "")),
            thread_history=[str(item) for item in raw_history],
            task_id=self.task_id,
            step_number=self._current_step,
            total_emails=len(self._emails),
        )

    def _terminal_observation(self) -> EmailObservation:
        """Build terminal observation returned when episode is complete.

        Returns:
            Terminal EmailObservation payload.
        """
        return EmailObservation(
            email_id="terminal",
            subject="Episode complete",
            body="No further emails remain for this task.",
            sender="system",
            timestamp="",
            thread_history=[],
            task_id=self.task_id,
            step_number=self._current_step,
            total_emails=len(self._emails),
        )

    def _grade_current_step(self, action: TriageAction) -> RewardResult:
        """Select deterministic grader based on task and current progress.

        Args:
            action: Validated action for the current step.

        Returns:
            RewardResult from the task-specific grader, or a minimal-score result
            when the task has no ground truth.
        """
        if not self._ground_truth:
            return RewardResult(
                score=SCORE_EPSILON,
                breakdown={"missing_ground_truth": 1.0 - SCORE_EPSILON},
                feedback="Missing ground truth for task.",
            )
        truth = self._current_truth()
        if self.task_id == "task_easy":
            return grade_easy(action, truth)
        if self.task_id == "task_medium":
            return grade_medium_step(action, truth)
        # task_hard, task_production and any other task id use the hard grader.
        return grade_hard(action, truth)

    def _compute_penalties(self, action: TriageAction) -> float:
        """Compute deterministic penalties according to reward policy.

        Must be called before ``action`` is appended to the action history.

        Args:
            action: Validated action for the step.

        Returns:
            Total penalty: 0.5 for archiving with a summary shorter than 10
            characters, plus 0.3 for a third consecutive identical label/route.
        """
        penalty_total = 0.0
        summary_too_short = len(action.summary.strip()) < 10
        if action.label == "archive" and summary_too_short:
            penalty_total += 0.5
        if self._is_repeated_action_pattern(action):
            penalty_total += 0.3
        return penalty_total

    def _compute_progress_signal(
        self,
        base_score: float,
        previous_base_score: float | None,
    ) -> float:
        """Compute dense partial-progress reward independent of final completion.

        Args:
            base_score: Current-step base grade in [0.0, 1.0].
            previous_base_score: Previous step base grade when available.

        Returns:
            Sum of a completion term (up to 0.05), a quality term (up to 0.05) and
            a trend term clamped to [-0.02, 0.03].
        """
        total_emails = max(len(self._emails), 1)
        progress_ratio = min(1.0, (self._current_index + 1) / total_emails)
        completion_signal = 0.05 * progress_ratio
        quality_signal = 0.05 * self._clip_reward(base_score)
        trend_signal = 0.0
        if previous_base_score is not None:
            delta = base_score - previous_base_score
            trend_signal = max(-0.02, min(0.03, delta * 0.1))
        return completion_signal + quality_signal + trend_signal

    def _compute_step_cost(self) -> float:
        """Return a gentle efficiency cost that grows linearly from 0.005 to 0.015."""
        normalized_step = self._current_step / max(self._max_steps, 1)
        return 0.005 + (0.01 * normalized_step)

    def _compute_trajectory_bonus(self) -> float:
        """Return trajectory bonus when episode completion quality is high.

        Called before the queue index advances, so "completion" means the email
        being graded is the last one in the (possibly extended) queue.

        Returns:
            0.2 when the mean base score exceeds 0.8 at completion, else 0.0.
        """
        if not self._base_score_history:
            return 0.0
        all_emails_done_after_step = self._current_index + 1 >= len(self._emails)
        if not all_emails_done_after_step:
            return 0.0
        mean_base = sum(self._base_score_history) / len(self._base_score_history)
        return 0.2 if mean_base > 0.8 else 0.0

    def _is_repeated_action_pattern(self, action: TriageAction) -> bool:
        """Detect whether same action appears three times consecutively.

        Expects ``_action_history`` to hold only the *previous* actions, i.e. to be
        called before ``action`` is recorded.

        Args:
            action: Current action.

        Returns:
            True when the two previous actions and ``action`` share the same label
            and (whitespace/case-insensitive) route.
        """
        if len(self._action_history) < 2:
            return False
        previous_action = self._action_history[-1]
        older_action = self._action_history[-2]
        return (
            previous_action.label == older_action.label == action.label
            and previous_action.route_to.strip().lower()
            == older_action.route_to.strip().lower()
            == action.route_to.strip().lower()
        )

    def _clip_reward(self, reward_value: float) -> float:
        """Clip reward to [SCORE_EPSILON, 1 - SCORE_EPSILON], keeping it inside (0.0, 1.0).

        Args:
            reward_value: Raw reward value.

        Returns:
            Clipped reward.
        """
        return max(SCORE_EPSILON, min(1.0 - SCORE_EPSILON, reward_value))