# openenv / environment.py
# Imaginephoenix's picture
# Upload 15 files
# 02e973e verified
# raw
# history blame
# 15.1 kB
"""Core OpenEnv email triage environment implementation."""
import os
from typing import cast
from pydantic import ValidationError
from graders import grade_easy, grade_hard, grade_medium_step
from models import (
EmailObservation,
EnvironmentState,
ResetResult,
RewardResult,
StepResult,
TriageAction,
)
from tasks import get_task_definition
class EmailTriageEnv:
    """Deterministic email triage environment implementing reset, step, and state.

    An episode walks through the emails of one task scenario in order. Each
    ``step`` grades the submitted action against the ground truth for the
    current email, applies a per-step cost, penalties, and an end-of-episode
    bonus, then advances to the next email. For ``task_production`` a weak or
    missed-critical action additionally queues an escalation follow-up email.
    """

    # Escalation presets for task_production:
    # mode -> (max generated follow-ups, follow-up quality threshold).
    _ESCALATION_PRESETS: dict[str, tuple[int, float]] = {
        "low": (2, 0.55),
        "normal": (4, 0.7),
        "high": (8, 0.85),
    }

    def __init__(
        self,
        task_id: str,
        scenario_index: int = 0,
        split: str | None = None,
        runtime_options: dict[str, object] | None = None,
    ) -> None:
        """Initialize environment with a selected task.

        Args:
            task_id: Task identifier such as task_easy, task_medium, or task_hard.
            scenario_index: Deterministic scenario index within the task pool.
                Negative values are clamped to 0.
            split: Scenario split, either public or private_eval. Falls back to
                the ``OPENENV_EVAL_SPLIT`` environment variable, then "public".
            runtime_options: Optional deterministic runtime controls for task generation.
        """
        self.task_id = task_id
        self._episode_index = max(0, scenario_index)
        self.split = split or os.getenv("OPENENV_EVAL_SPLIT", "public")
        self.runtime_options = runtime_options or {}
        self._load_scenario()

    def _load_scenario(self) -> None:
        """Load the scenario at the current episode index and clear episode state."""
        self._task_definition = get_task_definition(
            self.task_id,
            self._episode_index,
            self.split,
            self.runtime_options,
        )
        self._scenario_id = str(self._task_definition.get("scenario_id", "unknown"))
        # Copy both lists: follow-ups are inserted into them during an episode,
        # and the task definition object may be reused by its provider (not
        # visible from this module), so it must not be mutated in place.
        self._emails = list(
            cast(list[dict[str, object]], self._task_definition.get("emails") or [])
        )
        self._ground_truth = list(
            cast(list[dict[str, object]], self._task_definition.get("ground_truth") or [])
        )
        self._current_index = 0
        self._current_step = 0
        self._done = False
        # NOTE(review): the step budget is fixed here from the initial inbox
        # size; follow-ups inserted later do not extend it. Confirm intended.
        self._max_steps = max(10, len(self._emails) + 5)
        self._action_history: list[TriageAction] = []
        self._reward_history: list[float] = []
        self._base_score_history: list[float] = []
        self._generated_followups = 0
        self._configure_runtime_controls()

    def reset(self) -> ResetResult:
        """Reset episode state and return the first observation.

        The scenario for the current episode index is (re)loaded, after which
        the index is advanced so the next reset selects the following scenario.

        Returns:
            ResetResult containing first observation and metadata.
        """
        self._load_scenario()
        self._episode_index += 1
        return ResetResult(
            observation=self._build_observation(self._current_index),
            info=self._base_info(),
        )

    def step(self, action: TriageAction) -> StepResult:
        """Apply an action and return StepResult.

        Args:
            action: Proposed triage action.

        Returns:
            StepResult with next observation, reward, done flag, and metadata.
            Once the episode is done, further calls return a zero-reward
            terminal result flagged with ``already_done``.
        """
        if self._done:
            return StepResult(
                observation=self._terminal_observation(),
                reward=0.0,
                done=True,
                info={**self._base_info(), "already_done": True},
            )

        try:
            validated_action = TriageAction.model_validate(action)
        except ValidationError as validation_error:
            return self._invalid_action_result(validation_error)

        base_score = self._grade_current_step(validated_action).score
        self._maybe_enqueue_follow_up(validated_action, self._current_truth(), base_score)

        # Penalties must be computed before the action is recorded: the
        # repetition check compares the action with the two *previous* ones.
        penalties = self._compute_penalties(validated_action)
        self._action_history.append(validated_action)
        self._base_score_history.append(base_score)
        self._current_step += 1

        trajectory_bonus = self._compute_trajectory_bonus()
        final_reward = self._clip_reward(
            base_score - (self._current_step * 0.01) + trajectory_bonus - penalties
        )
        self._reward_history.append(final_reward)

        if self._current_index < len(self._emails):
            self._current_index += 1
        all_emails_processed = self._current_index >= len(self._emails)
        self._done = all_emails_processed or self._current_step >= self._max_steps

        if self._done:
            next_observation = self._terminal_observation()
        else:
            next_observation = self._build_observation(self._current_index)

        return StepResult(
            observation=next_observation,
            reward=final_reward,
            done=self._done,
            info={
                **self._base_info(),
                "base_score": round(base_score, 4),
                "penalties": round(penalties, 4),
                "trajectory_bonus": round(trajectory_bonus, 4),
            },
        )

    def _invalid_action_result(self, validation_error: ValidationError) -> StepResult:
        """Consume one step for an invalid action without advancing the inbox.

        Args:
            validation_error: Error raised while validating the action.

        Returns:
            Zero-reward StepResult carrying the validation error text.
        """
        self._current_step += 1
        self._reward_history.append(0.0)
        self._done = self._current_step >= self._max_steps
        # Match the valid-action path: a finished episode yields the terminal
        # observation rather than the email that was being triaged.
        if self._done:
            observation = self._terminal_observation()
        else:
            observation = self._build_observation(self._current_index)
        return StepResult(
            observation=observation,
            reward=0.0,
            done=self._done,
            info={**self._base_info(), "validation_error": str(validation_error)},
        )

    def _base_info(self) -> dict[str, object]:
        """Return the metadata fields shared by every result's ``info`` dict."""
        return {
            "task_id": self.task_id,
            "scenario_id": self._scenario_id,
            "split": self.split,
            "step": self._current_step,
        }

    def _current_truth(self) -> dict[str, object]:
        """Return ground truth for the current email, or ``{}`` when none exists.

        The index is clamped to the last ground-truth entry.
        """
        if not self._ground_truth:
            return {}
        return self._ground_truth[min(self._current_index, len(self._ground_truth) - 1)]

    def _maybe_enqueue_follow_up(
        self,
        action: TriageAction,
        truth: dict[str, object],
        base_score: float,
    ) -> None:
        """Insert deterministic escalation follow-up emails for production mode.

        A follow-up is queued directly after the current email when the task is
        ``task_production``, the follow-up limit has not been reached, and the
        action either missed an urgent email (wrong label or route) or scored
        below the follow-up quality threshold.

        Args:
            action: Validated action for the current email.
            truth: Ground truth for the current email.
            base_score: Grader score for the action.
        """
        if self.task_id != "task_production":
            return
        if self._generated_followups >= self._max_generated_followups:
            return
        if not self._emails:
            return

        expected_label = str(truth.get("label", ""))
        expected_route = str(truth.get("route_to", "general"))
        # Compare routes case-insensitively on both sides; the action side was
        # already lowercased, so a differently-cased expected route must not
        # count as a miss.
        routed_correctly = expected_route.strip().lower() in action.route_to.lower()
        is_missed_critical = expected_label == "urgent" and (
            action.label != "urgent" or not routed_correctly
        )
        if not is_missed_critical and base_score >= self._followup_quality_threshold:
            return

        source_email = self._emails[min(self._current_index, len(self._emails) - 1)]
        source_subject = str(source_email.get("subject", "Inbox incident"))
        source_timestamp = str(source_email.get("timestamp", "2026-04-03T00:00:00Z"))
        followup_email = {
            "email_id": f"followup-{self._scenario_id}-{self._generated_followups + 1}",
            "subject": f"Escalation follow-up: {source_subject}",
            "body": (
                "Automated escalation triggered because prior triage appears incomplete. "
                "Please route to the responsible team and provide a clear summary now."
            ),
            "sender": "incident-control@acme-enterprise.com",
            "timestamp": source_timestamp,
            "thread_history": [f"Previous message subject: {source_subject}"],
        }
        followup_truth = {
            "label": "urgent",
            "route_to": expected_route,
            # Source weight raised by 0.2 and kept within [1.5, 2.0].
            "priority_weight": min(max(float(truth.get("priority_weight", 1.5)) + 0.2, 1.5), 2.0),
            "summary_keywords": ["escalation", "follow-up", expected_route],
        }
        # Keep emails and ground truth aligned by inserting at the same index.
        insert_at = min(self._current_index + 1, len(self._emails))
        self._emails.insert(insert_at, followup_email)
        self._ground_truth.insert(insert_at, followup_truth)
        self._generated_followups += 1

    def _configure_runtime_controls(self) -> None:
        """Apply deterministic runtime control options for production simulator.

        Non-production tasks always use the "normal" preset. For
        ``task_production`` the ``escalation_mode`` runtime option selects the
        preset (case-insensitive); unknown modes fall back to "normal".
        """
        default_limits = self._ESCALATION_PRESETS["normal"]
        if self.task_id != "task_production":
            limits = default_limits
        else:
            mode = str(self.runtime_options.get("escalation_mode", "normal")).lower()
            limits = self._ESCALATION_PRESETS.get(mode, default_limits)
        self._max_generated_followups, self._followup_quality_threshold = limits

    def state(self) -> EnvironmentState:
        """Return read-only snapshot of full internal state.

        Returns:
            EnvironmentState with progress and copies of the action and reward
            histories.
        """
        return EnvironmentState(
            task_id=self.task_id,
            current_step=self._current_step,
            total_steps=self._max_steps,
            done=self._done,
            action_history=list(self._action_history),
            reward_history=list(self._reward_history),
        )

    def _build_observation(self, email_index: int) -> EmailObservation:
        """Build observation for the email at a given index.

        Args:
            email_index: Zero-based email index; clamped into the valid range.

        Returns:
            EmailObservation for the selected email or terminal placeholder
            when the inbox is empty.
        """
        if not self._emails:
            return self._terminal_observation()
        safe_index = min(max(email_index, 0), len(self._emails) - 1)
        email_payload = self._emails[safe_index]
        # Treat a missing or null thread history as empty.
        thread_items = cast(list[object], email_payload.get("thread_history") or [])
        return EmailObservation(
            email_id=str(email_payload.get("email_id", "")),
            subject=str(email_payload.get("subject", "")),
            body=str(email_payload.get("body", "")),
            sender=str(email_payload.get("sender", "")),
            timestamp=str(email_payload.get("timestamp", "")),
            thread_history=[str(item) for item in thread_items],
            task_id=self.task_id,
            step_number=self._current_step,
            total_emails=len(self._emails),
        )

    def _terminal_observation(self) -> EmailObservation:
        """Build terminal observation returned when episode is complete.

        Returns:
            Terminal EmailObservation payload.
        """
        return EmailObservation(
            email_id="terminal",
            subject="Episode complete",
            body="No further emails remain for this task.",
            sender="system",
            timestamp="",
            thread_history=[],
            task_id=self.task_id,
            step_number=self._current_step,
            total_emails=len(self._emails),
        )

    def _grade_current_step(self, action: TriageAction) -> RewardResult:
        """Select deterministic grader based on task and current progress.

        ``task_easy`` and ``task_medium`` use their dedicated graders; every
        other task id is graded with ``grade_hard``.

        Args:
            action: Validated action for the current step.

        Returns:
            RewardResult from task-specific grader, or a zero-score result
            when no ground truth is available.
        """
        if not self._ground_truth:
            return RewardResult(
                score=0.0,
                breakdown={"missing_ground_truth": 1.0},
                feedback="Missing ground truth for task.",
            )
        truth = self._current_truth()
        if self.task_id == "task_easy":
            return grade_easy(action, truth)
        if self.task_id == "task_medium":
            return grade_medium_step(action, truth)
        return grade_hard(action, truth)

    def _compute_penalties(self, action: TriageAction) -> float:
        """Compute deterministic penalties according to reward policy.

        Must be called before ``action`` is appended to the action history.

        Args:
            action: Validated action for the step.

        Returns:
            Total penalty value for current step: 0.5 for archiving with a
            summary shorter than 10 characters, plus 0.3 for a repeated
            action pattern.
        """
        penalty_total = 0.0
        summary_too_short = len(action.summary.strip()) < 10
        if action.label == "archive" and summary_too_short:
            penalty_total += 0.5
        if self._is_repeated_action_pattern(action):
            penalty_total += 0.3
        return penalty_total

    def _compute_trajectory_bonus(self) -> float:
        """Return trajectory bonus when episode completion quality is high.

        Called before the email index is advanced, hence the ``+ 1`` when
        checking whether the current email is the last one.

        Returns:
            0.2 when mean base score is above threshold at completion, else 0.0.
        """
        if not self._base_score_history:
            return 0.0
        if self._current_index + 1 < len(self._emails):
            return 0.0
        mean_base = sum(self._base_score_history) / len(self._base_score_history)
        return 0.2 if mean_base > 0.8 else 0.0

    def _is_repeated_action_pattern(self, action: TriageAction) -> bool:
        """Detect whether same action appears three times consecutively.

        Args:
            action: Current action, not yet recorded in the action history.

        Returns:
            True when repeated label and route occur three times in a row.
        """
        if len(self._action_history) < 2:
            return False
        recent = (self._action_history[-2], self._action_history[-1], action)
        same_label = len({item.label for item in recent}) == 1
        same_route = len({item.route_to.strip().lower() for item in recent}) == 1
        return same_label and same_route

    def _clip_reward(self, reward_value: float) -> float:
        """Clip reward to the inclusive range [-1.0, 1.0].

        Args:
            reward_value: Raw reward value.

        Returns:
            Clipped reward.
        """
        return max(-1.0, min(1.0, reward_value))