""" Email Triage & Response Environment OpenEnv-compatible environment for agent evaluation. """ import json from typing import Optional, Literal from pydantic import BaseModel, Field class Email(BaseModel): id: str from_: str = Field(..., alias="from") subject: str body: str labels: list[str] = [] replied: bool = False archived: bool = False flagged: bool = False flag_reason: Optional[str] = None reply_body: Optional[str] = None class Config: populate_by_name = True class InboxState(BaseModel): inbox: list[Email] sent: list[dict] = [] step_count: int = 0 class Observation(BaseModel): status: str message: str data: Optional[dict] = None step_count: int = 0 class Action(BaseModel): action: Literal["label", "draft_reply", "archive", "flag", "read", "list_inbox"] email_id: Optional[str] = None priority: Optional[Literal["urgent", "normal", "low"]] = None body: Optional[str] = None reason: Optional[str] = None class StepResult(BaseModel): observation: Observation reward: float done: bool info: dict = {} import os # Load dataset (generated by curate_dataset.py) DATASET_PATH = os.path.join(os.path.dirname(__file__), "data", "emails.json") try: with open(DATASET_PATH, "r", encoding="utf-8") as f: _dataset = json.load(f) except FileNotFoundError: # Fallback to empty if not curated yet, though curate_dataset.py should be run first _dataset = {"task1": {"emails": [], "ground_truth": {}}, "task2": {"emails": []}, "task3": {"emails": [], "ground_truth": {}, "urgent_ids": [], "archive_ids": [], "flag_ids": []}} TASK1_EMAILS = _dataset["task1"]["emails"] TASK1_GROUND_TRUTH = _dataset["task1"].get("ground_truth", {}) TASK2_EMAIL = _dataset["task2"]["emails"][0] if _dataset["task2"]["emails"] else {} TASK3_EMAILS = _dataset["task3"]["emails"] TASK3_GROUND_TRUTH = _dataset["task3"].get("ground_truth", {}) TASK3_URGENT_IDS = set(_dataset["task3"].get("urgent_ids", [])) TASK3_ARCHIVE_IDS = set(_dataset["task3"].get("archive_ids", [])) TASK3_FLAG_IDS = set(_dataset["task3"].get("flag_ids", [])) def grade_task1(state: InboxState) -> float: score = 0.0 for email in state.inbox: gt = TASK1_GROUND_TRUTH.get(email.id) if gt and "urgent" in email.labels and gt == "urgent": score += 0.2 elif gt and "normal" in email.labels and gt == "normal": score += 0.2 elif gt and "low" in email.labels and gt == "low": score += 0.2 return round(min(score, 1.0), 2) def grade_task2(state: InboxState) -> float: score = 0.0 email = next((e for e in state.inbox if e.id == "t2_001"), None) if email is None or not email.replied or not email.reply_body: return 0.0 reply = email.reply_body.lower() issues_covered = 0 if "order" in reply and ("48291" in reply or "order" in reply): issues_covered += 1 if any(w in reply for w in ["refund", "deliver", "shipment", "track"]): issues_covered += 1 if any(w in reply for w in ["compensat", "apologi", "sorry", "inconvenien"]): issues_covered += 1 score += 0.1 * issues_covered # up to 0.3 # +0.3 professional tone professional_signals = ["dear", "sincerely", "regards", "thank you", "we apologize", "we understand", "please", "we will"] rude_signals = ["whatever", "not our fault", "calm down"] tone_score = sum(1 for w in professional_signals if w in reply) rude_penalty = sum(1 for w in rude_signals if w in reply) score += min(0.3, tone_score * 0.05) - (rude_penalty * 0.1) # +0.2 correct recipient / subject handling if email.reply_body and len(email.reply_body) > 50: score += 0.2 # +0.2 no fabricated facts (heuristic: no invented order dates / amounts) fabrication_signals = ["$", "€", "refund amount", "exact date", "tracking number is"] fab_hits = sum(1 for w in fabrication_signals if w in reply) if fab_hits == 0: score += 0.2 return round(max(0.0, min(score, 1.0)), 2) def grade_task3(state: InboxState, penalties: dict) -> float: score = 0.0 email_map = {e.id: e for e in state.inbox} # Priority labels (0.2 per correct, 10 emails = max 2.0 → normalise to 0.5 weight) label_score = 0.0 for eid, gt in TASK3_GROUND_TRUTH.items(): email = email_map.get(eid) if email and gt in email.labels: label_score += 0.2 score += min(label_score, 2.0) * 0.25 # normalise to 0.5 # Replies for urgent emails (max 0.4) reply_scores = [] for eid in TASK3_URGENT_IDS: email = email_map.get(eid) if email and email.replied and email.reply_body: reply_scores.append(min(len(email.reply_body) / 200, 1.0) * 0.1) score += sum(reply_scores) # Archive spam (0.05 each, max 0.1) for eid in TASK3_ARCHIVE_IDS: email = email_map.get(eid) if email and email.archived: score += 0.05 # Flag ambiguous (0.05 each) for eid in TASK3_FLAG_IDS: email = email_map.get(eid) if email and email.flagged: score += 0.05 # Penalties score -= penalties.get("destructive_actions", 0) * 0.1 score -= penalties.get("loop_actions", 0) * 0.05 return round(max(0.0, min(score, 1.0)), 2) # --------------------------------------------------------------------------- # Environment Class # --------------------------------------------------------------------------- class EmailTriageEnv: """OpenEnv-compatible Email Triage environment.""" TASKS = {1, 2, 3} def __init__(self, task: int = 1): assert task in self.TASKS, f"task must be one of {self.TASKS}" self.task = task self._state: Optional[InboxState] = None self._penalties = {"destructive_actions": 0, "loop_actions": 0} self._action_history: list[str] = [] self._done = False # ------------------------------------------------------------------ # OpenEnv interface # ------------------------------------------------------------------ def reset(self) -> Observation: self._penalties = {"destructive_actions": 0, "loop_actions": 0} self._action_history = [] self._done = False if self.task == 1: emails = [Email.model_validate(e) for e in TASK1_EMAILS] elif self.task == 2: emails = [Email.model_validate(TASK2_EMAIL)] else: emails = [Email.model_validate(e) for e in TASK3_EMAILS] self._state = InboxState(inbox=emails) return Observation( status="ok", message=f"Task {self.task} environment reset. Inbox contains {len(emails)} email(s).", data={"task": self.task, "inbox_size": len(emails)}, step_count=0, ) def state(self) -> dict: assert self._state is not None, "Call reset() first." return json.loads(self._state.model_dump_json(by_alias=True)) def step(self, action: Action) -> StepResult: assert self._state is not None, "Call reset() first." if self._done: return StepResult( observation=Observation(status="done", message="Episode already finished.", step_count=self._state.step_count), reward=0.0, done=True, ) self._state.step_count += 1 action_key = f"{action.action}:{action.email_id}" # Loop detection if self._action_history.count(action_key) >= 2: self._penalties["loop_actions"] += 1 self._action_history.append(action_key) obs, reward = self._dispatch(action) obs.step_count = self._state.step_count return StepResult(observation=obs, reward=reward, done=self._done) def score(self) -> float: """Return current cumulative score (0-1).""" assert self._state is not None, "Call reset() first." if self.task == 1: return grade_task1(self._state) elif self.task == 2: return grade_task2(self._state) else: return grade_task3(self._state, self._penalties) # ------------------------------------------------------------------ # Action dispatch # ------------------------------------------------------------------ def _dispatch(self, action: Action): handlers = { "list_inbox": self._act_list_inbox, "read": self._act_read, "label": self._act_label, "draft_reply":self._act_draft_reply, "archive": self._act_archive, "flag": self._act_flag, } handler = handlers.get(action.action) if handler is None: return Observation(status="error", message=f"Unknown action: {action.action}"), 0.0 return handler(action) def _act_list_inbox(self, action: Action): summaries = [ {"id": e.id, "from": e.from_, "subject": e.subject, "labels": e.labels, "replied": e.replied, "archived": e.archived, "flagged": e.flagged} for e in self._state.inbox ] return Observation(status="ok", message="Inbox listed.", data={"emails": summaries}), 0.0 def _act_read(self, action: Action): email = self._find(action.email_id) if email is None: return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0 return Observation( status="ok", message=f"Read email {action.email_id}.", data=json.loads(email.model_dump_json(by_alias=True)), ), 0.0 def _act_label(self, action: Action): email = self._find(action.email_id) if email is None: return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0 if action.priority not in ("urgent", "normal", "low"): return Observation(status="error", message="priority must be urgent | normal | low"), 0.0 # Remove existing priority labels then add new email.labels = [l for l in email.labels if l not in ("urgent", "normal", "low")] email.labels.append(action.priority) incremental = self._incremental_label_reward(email.id, action.priority) return Observation( status="ok", message=f"Labelled {action.email_id} as {action.priority}.", data={"email_id": action.email_id, "priority": action.priority}, ), incremental def _act_draft_reply(self, action: Action): email = self._find(action.email_id) if email is None: return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0 if not action.body or len(action.body.strip()) < 10: return Observation(status="error", message="Reply body too short."), 0.0 email.replied = True email.reply_body = action.body self._state.sent.append({"to": email.from_, "subject": f"Re: {email.subject}", "body": action.body}) return Observation(status="ok", message=f"Reply drafted for {action.email_id}."), 0.0 def _act_archive(self, action: Action): email = self._find(action.email_id) if email is None: return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0 # Penalty if archiving urgent email if "urgent" in email.labels: self._penalties["destructive_actions"] += 1 return Observation( status="warning", message=f"Archived urgent email {action.email_id} — penalty applied.", ), -0.1 email.archived = True return Observation(status="ok", message=f"Email {action.email_id} archived."), 0.0 def _act_flag(self, action: Action): email = self._find(action.email_id) if email is None: return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0 email.flagged = True email.flag_reason = action.reason or "unspecified" return Observation(status="ok", message=f"Email {action.email_id} flagged for human review."), 0.0 # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ def _find(self, email_id: Optional[str]) -> Optional[Email]: if email_id is None: return None return next((e for e in self._state.inbox if e.id == email_id), None) def _incremental_label_reward(self, email_id: str, priority: str) -> float: """Return +0.2 if label matches ground truth for task 1.""" if self.task == 1: gt = TASK1_GROUND_TRUTH.get(email_id) return 0.2 if gt == priority else 0.0 return 0.0