Spaces:
Sleeping
Sleeping
Dev Shah
fix: use Pydantic model_validate instead of unpacking with aliases to resolve Post Reset failure
"""
Email Triage & Response Environment
OpenEnv-compatible environment for agent evaluation.

Three tasks are provided, each graded to a score in [0, 1]:
  1. Priority labelling of an inbox (grade_task1).
  2. Replying to a single customer email (grade_task2).
  3. Full triage: priority labels, replies to urgent emails, archiving and
     flagging for human review, minus penalties (grade_task3).

Email data is loaded at import time from data/emails.json.
"""
import json
from typing import Literal, Optional

from pydantic import BaseModel, ConfigDict, Field
class Email(BaseModel):
    """A single message in the simulated inbox.

    The sender is serialised as ``from`` (a Python keyword, hence the
    ``from_`` attribute); ``populate_by_name`` lets input use either name.
    """

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = ConfigDict(populate_by_name=True)

    id: str
    from_: str = Field(..., alias="from")
    subject: str
    body: str
    # The "label" action keeps at most one of "urgent" / "normal" / "low" here.
    labels: list[str] = []
    # Set by the "draft_reply" action together with reply_body.
    replied: bool = False
    # Set by the "archive" action (refused for emails labelled urgent).
    archived: bool = False
    # Set by the "flag" action together with flag_reason.
    flagged: bool = False
    flag_reason: Optional[str] = None
    reply_body: Optional[str] = None
class InboxState(BaseModel):
    """Per-episode mutable state: the inbox, replies sent and the step counter."""

    inbox: list[Email]
    # One {"to", "subject", "body"} dict per drafted reply.
    sent: list[dict] = Field(default=[])
    step_count: int = Field(default=0)
class Observation(BaseModel):
    """What the agent sees after a reset or a step."""

    # "ok", "error", "warning" or "done" within this module.
    status: str
    # Human-readable description of the outcome.
    message: str
    # Action-specific payload, when there is one.
    data: Optional[dict] = Field(default=None)
    step_count: int = Field(default=0)
class Action(BaseModel):
    """One agent command; which optional fields are used depends on ``action``."""

    action: Literal["label", "draft_reply", "archive", "flag", "read", "list_inbox"]
    # Target email; not needed for "list_inbox".
    email_id: Optional[str] = Field(default=None)
    # Used by "label".
    priority: Optional[Literal["urgent", "normal", "low"]] = Field(default=None)
    # Used by "draft_reply".
    body: Optional[str] = Field(default=None)
    # Used by "flag".
    reason: Optional[str] = Field(default=None)
class StepResult(BaseModel):
    """Outcome of one environment step: observation, reward and done flag."""

    observation: Observation
    reward: float
    done: bool
    info: dict = Field(default={})
import os
import warnings

# Load dataset (generated by curate_dataset.py)
DATASET_PATH = os.path.join(os.path.dirname(__file__), "data", "emails.json")
try:
    with open(DATASET_PATH, "r", encoding="utf-8") as f:
        _dataset = json.load(f)
except FileNotFoundError:
    # Fall back to an empty dataset so the module still imports, but say so:
    # otherwise every task silently starts with an empty inbox.
    warnings.warn(
        f"Dataset not found at {DATASET_PATH}; using an empty dataset. "
        "Run curate_dataset.py to generate it.",
        RuntimeWarning,
    )
    _dataset = {"task1": {"emails": [], "ground_truth": {}}, "task2": {"emails": []}, "task3": {"emails": [], "ground_truth": {}, "urgent_ids": [], "archive_ids": [], "flag_ids": []}}

# Missing tasks or keys degrade to empty values instead of raising KeyError
# at import time.
_task1 = _dataset.get("task1", {})
_task2 = _dataset.get("task2", {})
_task3 = _dataset.get("task3", {})

TASK1_EMAILS = _task1.get("emails", [])
TASK1_GROUND_TRUTH = _task1.get("ground_truth", {})
# Task 2 uses only the first email; {} when none is available.
TASK2_EMAIL = _task2["emails"][0] if _task2.get("emails") else {}
TASK3_EMAILS = _task3.get("emails", [])
TASK3_GROUND_TRUTH = _task3.get("ground_truth", {})
TASK3_URGENT_IDS = set(_task3.get("urgent_ids", []))
TASK3_ARCHIVE_IDS = set(_task3.get("archive_ids", []))
TASK3_FLAG_IDS = set(_task3.get("flag_ids", []))
def grade_task1(state: InboxState) -> float:
    """Score task 1 (priority labelling) in [0, 1].

    Each inbox email whose ground-truth priority ("urgent", "normal" or
    "low") appears among its labels earns 0.2. The total is capped at 1.0
    and rounded to two decimals.
    """
    hits = 0
    for msg in state.inbox:
        expected = TASK1_GROUND_TRUTH.get(msg.id)
        if expected in ("urgent", "normal", "low") and expected in msg.labels:
            hits += 1
    return round(min(hits * 0.2, 1.0), 2)
def grade_task2(state: InboxState) -> float:
    """Score task 2 (customer reply) in [0, 1].

    Components, computed on the lower-cased reply body:
      * up to 0.3: 0.1 per issue addressed (references the order; mentions
        refund/delivery/shipment/tracking; apologises or offers compensation);
      * up to 0.3: 0.05 per professional-tone phrase, minus 0.1 per rude phrase;
      * 0.2: the reply is longer than 50 characters;
      * 0.2: none of the "fabrication" markers appear.
    Returns 0.0 if the task email is absent or has no reply.
    """
    # Use the dataset's own id so a regenerated dataset is still graded.
    target_id = TASK2_EMAIL.get("id", "t2_001")
    email = next((e for e in state.inbox if e.id == target_id), None)
    if email is None or not email.replied or not email.reply_body:
        return 0.0
    reply = email.reply_body.lower()

    score = 0.0
    issues_covered = 0
    # Order reference: either the order number or the word "order" counts.
    # NOTE(review): "48291" is presumably the order number in the task-2 email.
    if "48291" in reply or "order" in reply:
        issues_covered += 1
    if any(w in reply for w in ["refund", "deliver", "shipment", "track"]):
        issues_covered += 1
    if any(w in reply for w in ["compensat", "apologi", "sorry", "inconvenien"]):
        issues_covered += 1
    score += 0.1 * issues_covered  # up to 0.3

    # +0.3 professional tone
    professional_signals = ["dear", "sincerely", "regards", "thank you", "we apologize",
                            "we understand", "please", "we will"]
    rude_signals = ["whatever", "not our fault", "calm down"]
    tone_score = sum(1 for w in professional_signals if w in reply)
    rude_penalty = sum(1 for w in rude_signals if w in reply)
    score += min(0.3, tone_score * 0.05) - (rude_penalty * 0.1)

    # +0.2 substantive reply: more than 50 characters (reply_body is non-empty here).
    if len(email.reply_body) > 50:
        score += 0.2

    # +0.2 no fabricated facts (heuristic: no invented order dates / amounts)
    fabrication_signals = ["$", "€", "refund amount", "exact date", "tracking number is"]
    if not any(w in reply for w in fabrication_signals):
        score += 0.2

    return round(max(0.0, min(score, 1.0)), 2)
def grade_task3(state: InboxState, penalties: dict) -> float:
    """Score task 3 (full inbox triage) in [0, 1].

    Components:
      * priority labels: 0.2 per email whose ground-truth label is present,
        capped at 2.0 and scaled by 0.25 (so at most 0.5);
      * replies to urgent emails: up to 0.1 each, proportional to reply
        length (full credit at 200+ characters);
      * 0.05 per spam email archived and 0.05 per ambiguous email flagged;
      * minus 0.1 per "destructive_actions" and 0.05 per "loop_actions"
        recorded in ``penalties``.
    The total is clamped to [0, 1] and rounded to two decimals.
    """
    by_id = {msg.id: msg for msg in state.inbox}

    # Priority labels, normalised to a 0.5 weight.
    label_points = 0.0
    for email_id, expected in TASK3_GROUND_TRUTH.items():
        msg = by_id.get(email_id)
        if msg is not None and expected in msg.labels:
            label_points += 0.2
    total = min(label_points, 2.0) * 0.25

    # Replies to urgent emails, scaled by reply length.
    total += sum(
        [
            min(len(msg.reply_body) / 200, 1.0) * 0.1
            for msg in (by_id.get(eid) for eid in TASK3_URGENT_IDS)
            if msg is not None and msg.replied and msg.reply_body
        ]
    )

    # Archived spam first, then flagged ambiguous emails: 0.05 each.
    for wanted_ids, attr in ((TASK3_ARCHIVE_IDS, "archived"), (TASK3_FLAG_IDS, "flagged")):
        for email_id in wanted_ids:
            msg = by_id.get(email_id)
            if msg is not None and getattr(msg, attr):
                total += 0.05

    # Penalties accumulated by the environment.
    total -= penalties.get("destructive_actions", 0) * 0.1
    total -= penalties.get("loop_actions", 0) * 0.05
    return round(max(0.0, min(total, 1.0)), 2)
# ---------------------------------------------------------------------------
# Environment Class
# ---------------------------------------------------------------------------
class EmailTriageEnv:
    """OpenEnv-compatible Email Triage environment.

    Construct with a task number, call ``reset()`` to load that task's inbox,
    then call ``step(action)`` repeatedly; ``score()`` grades the current
    inbox state with the task's grader.
    """

    TASKS = {1, 2, 3}
    # Priority labels managed by the "label" action.
    _PRIORITIES = ("urgent", "normal", "low")

    def __init__(self, task: int = 1):
        """Create an environment for ``task``; ``reset()`` must be called before use.

        Raises:
            AssertionError: if ``task`` is not one of ``TASKS``.
        """
        assert task in self.TASKS, f"task must be one of {self.TASKS}"
        self.task = task
        self._state: Optional[InboxState] = None
        # Counters consumed by grade_task3.
        self._penalties = {"destructive_actions": 0, "loop_actions": 0}
        # "action:email_id" key of every step taken, used for loop detection.
        self._action_history: list[str] = []
        # NOTE(review): nothing in this class sets this to True, so step()
        # always reports done=False; termination is left to the caller.
        self._done = False

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    def reset(self) -> Observation:
        """Start a fresh episode and return the initial observation.

        Clears penalties and action history and rebuilds the inbox for
        ``self.task`` from the module-level dataset (step counter back to 0).
        """
        self._penalties = {"destructive_actions": 0, "loop_actions": 0}
        self._action_history = []
        self._done = False
        if self.task == 1:
            emails = [Email.model_validate(e) for e in TASK1_EMAILS]
        elif self.task == 2:
            # TASK2_EMAIL is {} when the dataset is missing; validating {} would
            # raise, so use an empty inbox as tasks 1 and 3 do in that case.
            emails = [Email.model_validate(TASK2_EMAIL)] if TASK2_EMAIL else []
        else:
            emails = [Email.model_validate(e) for e in TASK3_EMAILS]
        self._state = InboxState(inbox=emails)
        return Observation(
            status="ok",
            message=f"Task {self.task} environment reset. Inbox contains {len(emails)} email(s).",
            data={"task": self.task, "inbox_size": len(emails)},
            step_count=0,
        )

    def state(self) -> dict:
        """Return a JSON-compatible dict of the inbox state (sender keyed as ``from``)."""
        assert self._state is not None, "Call reset() first."
        return self._state.model_dump(mode="json", by_alias=True)

    def step(self, action: Action) -> StepResult:
        """Apply ``action`` and return the observation, step reward and done flag.

        Each call increments the step counter. An ``action:email_id`` key
        already seen at least twice counts as a loop action (penalised only by
        grade_task3).
        """
        assert self._state is not None, "Call reset() first."
        if self._done:
            return StepResult(
                observation=Observation(status="done", message="Episode already finished.", step_count=self._state.step_count),
                reward=0.0,
                done=True,
            )
        self._state.step_count += 1
        action_key = f"{action.action}:{action.email_id}"
        # Loop detection: third and later repetitions of a key are penalised.
        if self._action_history.count(action_key) >= 2:
            self._penalties["loop_actions"] += 1
        self._action_history.append(action_key)
        obs, reward = self._dispatch(action)
        obs.step_count = self._state.step_count
        return StepResult(observation=obs, reward=reward, done=self._done)

    def score(self) -> float:
        """Return current cumulative score (0-1) from the task's grader."""
        assert self._state is not None, "Call reset() first."
        if self.task == 1:
            return grade_task1(self._state)
        elif self.task == 2:
            return grade_task2(self._state)
        else:
            return grade_task3(self._state, self._penalties)

    # ------------------------------------------------------------------
    # Action dispatch
    # ------------------------------------------------------------------

    def _dispatch(self, action: Action):
        """Route ``action`` to its handler; returns ``(Observation, reward)``."""
        handlers = {
            "list_inbox": self._act_list_inbox,
            "read": self._act_read,
            "label": self._act_label,
            "draft_reply": self._act_draft_reply,
            "archive": self._act_archive,
            "flag": self._act_flag,
        }
        handler = handlers.get(action.action)
        if handler is None:
            return Observation(status="error", message=f"Unknown action: {action.action}"), 0.0
        return handler(action)

    def _act_list_inbox(self, action: Action):
        """Return a summary (no bodies) of every email in the inbox."""
        summaries = [
            {"id": e.id, "from": e.from_, "subject": e.subject,
             "labels": e.labels, "replied": e.replied, "archived": e.archived, "flagged": e.flagged}
            for e in self._state.inbox
        ]
        return Observation(status="ok", message="Inbox listed.", data={"emails": summaries}), 0.0

    def _act_read(self, action: Action):
        """Return the full email (JSON-compatible, sender keyed as ``from``)."""
        email = self._find(action.email_id)
        if email is None:
            return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0
        return Observation(
            status="ok",
            message=f"Read email {action.email_id}.",
            data=email.model_dump(mode="json", by_alias=True),
        ), 0.0

    def _act_label(self, action: Action):
        """Set the email's priority label, replacing any previous one.

        The reward is the change in label correctness (task 1 only): the
        reward of the new label minus that of the label it replaces. Repeating
        or toggling labels therefore cannot farm reward; a correct label that
        is replaced by a wrong one yields a negative reward.
        """
        email = self._find(action.email_id)
        if email is None:
            return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0
        if action.priority not in self._PRIORITIES:
            return Observation(status="error", message="priority must be urgent | normal | low"), 0.0
        # Reward already credited for the email's current priority label, if any.
        previous_reward = max(
            (self._incremental_label_reward(email.id, label)
             for label in email.labels if label in self._PRIORITIES),
            default=0.0,
        )
        # Remove existing priority labels then add new
        email.labels = [label for label in email.labels if label not in self._PRIORITIES]
        email.labels.append(action.priority)
        incremental = self._incremental_label_reward(email.id, action.priority) - previous_reward
        return Observation(
            status="ok",
            message=f"Labelled {action.email_id} as {action.priority}.",
            data={"email_id": action.email_id, "priority": action.priority},
        ), incremental

    def _act_draft_reply(self, action: Action):
        """Record a reply (at least 10 non-blank chars) and append it to ``sent``."""
        email = self._find(action.email_id)
        if email is None:
            return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0
        if not action.body or len(action.body.strip()) < 10:
            return Observation(status="error", message="Reply body too short."), 0.0
        email.replied = True
        email.reply_body = action.body
        self._state.sent.append({"to": email.from_, "subject": f"Re: {email.subject}", "body": action.body})
        return Observation(status="ok", message=f"Reply drafted for {action.email_id}."), 0.0

    def _act_archive(self, action: Action):
        """Archive the email unless it is labelled urgent.

        Archiving an urgent-labelled email is refused (the email is left
        unarchived), counted as a destructive action and rewarded -0.1.
        """
        email = self._find(action.email_id)
        if email is None:
            return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0
        # Penalty if archiving urgent email
        if "urgent" in email.labels:
            self._penalties["destructive_actions"] += 1
            return Observation(
                status="warning",
                message=f"Archived urgent email {action.email_id} — penalty applied.",
            ), -0.1
        email.archived = True
        return Observation(status="ok", message=f"Email {action.email_id} archived."), 0.0

    def _act_flag(self, action: Action):
        """Flag the email for human review, recording the reason (default "unspecified")."""
        email = self._find(action.email_id)
        if email is None:
            return Observation(status="error", message=f"Email {action.email_id} not found."), 0.0
        email.flagged = True
        email.flag_reason = action.reason or "unspecified"
        return Observation(status="ok", message=f"Email {action.email_id} flagged for human review."), 0.0

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find(self, email_id: Optional[str]) -> Optional[Email]:
        """Return the inbox email with ``email_id``, or None if absent or id is None."""
        if email_id is None:
            return None
        return next((e for e in self._state.inbox if e.id == email_id), None)

    def _incremental_label_reward(self, email_id: str, priority: str) -> float:
        """Return 0.2 if ``priority`` matches task-1 ground truth for ``email_id``, else 0.0.

        Always 0.0 for tasks 2 and 3.
        """
        if self.task == 1:
            gt = TASK1_GROUND_TRUTH.get(email_id)
            return 0.2 if gt == priority else 0.0
        return 0.0