Spaces:
Sleeping
Sleeping
"""
Email Triage Environment — core RL environment logic.

Implements the OpenEnv interface:
    reset(task_id) → EmailObservation
    step(action)   → StepResult
    state()        → EmailState

Three tasks of increasing difficulty:
    Task 1 (easy)   — Spam Detection    : label only (spam | inbox)
    Task 2 (medium) — Priority Triage   : label + priority
    Task 3 (hard)   — Full Email Triage : label + priority + category
"""
| import random | |
| import uuid | |
| from typing import Dict, List, Optional | |
| from emails import EMAILS | |
| from models import ( | |
| EmailAction, | |
| EmailObservation, | |
| EmailReward, | |
| EmailState, | |
| StepResult, | |
| ) | |
# ─────────────────────────────────────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────────────────────────────────────

# Display metadata for each supported task, keyed by task id.
# The keys double as the set of valid ``task_id`` values.
TASK_META: Dict[int, Dict[str, str]] = {
    1: {"name": "Spam Detection", "difficulty": "easy"},
    2: {"name": "Priority Triage", "difficulty": "medium"},
    3: {"name": "Full Email Triage", "difficulty": "hard"},
}

# Numeric mapping for priority comparison (to give partial credit):
# priorities whose ranks differ by exactly 1 are treated as "close".
PRIORITY_RANK: Dict[str, int] = {"high": 2, "medium": 1, "low": 0}

# ─────────────────────────────────────────────────────────────────────────────
# Environment
# ─────────────────────────────────────────────────────────────────────────────
class EmailTriageEnvironment:
    """
    Stateful RL environment for email triage.

    An episode consists of triaging every email in the dataset once, in
    shuffled order. The agent receives one email per step and must classify
    it according to the active task's requirements.

    Usage:
        env = EmailTriageEnvironment()
        obs = env.reset(task_id=1)
        while not obs.done:
            action = EmailAction(label="spam")
            result = env.step(action)
            obs = result.observation
        score = env.get_episode_score()
    """

    # Reward deducted when an urgent email's priority is under-estimated.
    _MISSED_URGENT_PENALTY = 0.30
    # Reward deducted (Task 3 only) when legitimate mail is labelled spam.
    _FALSE_SPAM_PENALTY = 0.20

    def __init__(self) -> None:
        """Create an idle environment; call reset() before the first step()."""
        self.episode_id: Optional[str] = None
        self.task_id: int = 1
        self.emails: List[dict] = []
        self.current_index: int = 0
        self.step_count: int = 0
        self.cumulative_reward: float = 0.0
        self.done: bool = True  # no episode is active until reset()
        self.results: List[dict] = []  # history — used by grader

    # ──────────────────────────────────────────────────────────────
    # PUBLIC API (OpenEnv spec)
    # ──────────────────────────────────────────────────────────────

    def reset(self, task_id: int = 1, seed: Optional[int] = None) -> "EmailObservation":
        """
        Start a new episode. Shuffles emails and returns the first one.

        Args:
            task_id: Which task to run; must be a key of TASK_META.
            seed: Optional seed for a reproducible email order. When None
                (the default) the module-level ``random`` generator is used.

        Returns:
            The observation for the first email, or a terminal observation
            if the dataset is empty.

        Raises:
            ValueError: If ``task_id`` is not a known task.
        """
        if task_id not in TASK_META:
            raise ValueError(f"task_id must be 1, 2, or 3 — got {task_id}")

        self.episode_id = str(uuid.uuid4())
        self.task_id = task_id

        # Shuffle a copy of the list so the shared dataset order is untouched.
        self.emails = list(EMAILS)
        if seed is None:
            random.shuffle(self.emails)
        else:
            random.Random(seed).shuffle(self.emails)

        self.current_index = 0
        self.step_count = 0
        self.cumulative_reward = 0.0
        # With nothing to triage the episode is over immediately; keeping the
        # flag in sync with the terminal observation returned below avoids a
        # misleading "No more emails" error on the next step().
        self.done = not self.emails
        self.results = []
        return self._make_observation(reward=0.0)

    def step(self, action: "EmailAction") -> "StepResult":
        """
        Submit a classification action for the current email.

        Returns the next email as an observation, plus the reward earned.

        Raises:
            ValueError: If the episode is already finished (or was never
                started), or if the email cursor has run past the dataset.
        """
        if self.done:
            raise ValueError(
                "Episode is already done. Call reset() to start a new episode."
            )
        if self.current_index >= len(self.emails):
            raise ValueError("No more emails — episode should have ended.")

        current_email = self.emails[self.current_index]

        # Score the action
        reward_info = self._calculate_reward(action, current_email)

        # Record result for grader
        self.results.append(
            {
                "email_id": current_email["id"],
                "subject": current_email["subject"],
                "action": action.model_dump(),
                "ground_truth": current_email["ground_truth"],
                "reward": reward_info.value,
                "feedback": reward_info.feedback,
                "penalties": reward_info.penalties,
            }
        )

        # Advance state
        self.cumulative_reward += reward_info.value
        self.step_count += 1
        self.current_index += 1
        if self.current_index >= len(self.emails):
            self.done = True

        obs = self._make_observation(reward=reward_info.value)
        return StepResult(
            observation=obs,
            reward=reward_info.value,
            done=self.done,
            info={
                "label_score": reward_info.label_score,
                "priority_score": reward_info.priority_score,
                "category_score": reward_info.category_score,
                "feedback": reward_info.feedback,
                "penalties": reward_info.penalties,
                "ground_truth": current_email["ground_truth"],
            },
        )

    def state(self) -> "EmailState":
        """Return current episode metadata without email content."""
        # Before the first reset() there is no episode list yet, so report
        # the size of the full dataset instead.
        n = len(self.emails) if self.emails else len(EMAILS)
        # Running score is the mean per-step reward so far.
        score = self.cumulative_reward / self.step_count if self.step_count > 0 else 0.0
        meta = TASK_META.get(self.task_id, {"name": "Unknown", "difficulty": "unknown"})
        return EmailState(
            episode_id=self.episode_id or "",
            task_id=self.task_id,
            task_name=meta["name"],
            task_difficulty=meta["difficulty"],
            step_count=self.step_count,
            total_emails=n,
            cumulative_reward=round(self.cumulative_reward, 4),
            score=round(max(0.0, min(1.0, score)), 4),
            done=self.done,
        )

    def get_episode_score(self) -> float:
        """
        Final normalized score for the completed episode [0.0, 1.0].

        Computed as the mean recorded per-step reward, clamped to [0, 1] and
        rounded to 4 decimals. Returns 0.0 when no steps have been recorded.
        Called by the grader after the episode is done.
        """
        if not self.results:
            return 0.0
        avg = sum(r["reward"] for r in self.results) / len(self.results)
        return round(max(0.0, min(1.0, avg)), 4)

    # ──────────────────────────────────────────────────────────────
    # PRIVATE HELPERS
    # ──────────────────────────────────────────────────────────────

    def _make_observation(self, reward: float) -> "EmailObservation":
        """
        Build an observation from current environment state.

        Args:
            reward: Reward earned by the step that led to this observation
                (0.0 for the observation returned by reset()).
        """
        # Terminal observation — no email content
        if self.done or self.current_index >= len(self.emails):
            return EmailObservation(
                email_id="EPISODE_DONE",
                subject="",
                sender="",
                body="All emails have been processed. Call /reset to start a new episode.",
                timestamp="",
                step=self.step_count,
                total_emails=len(self.emails),
                emails_remaining=0,
                reward=reward,
                cumulative_reward=round(self.cumulative_reward, 4),
                done=True,
            )

        email = self.emails[self.current_index]
        return EmailObservation(
            email_id=email["id"],
            subject=email["subject"],
            sender=email["sender"],
            body=email["body"],
            timestamp=email["timestamp"],
            step=self.step_count,
            total_emails=len(self.emails),
            # Counts the email being shown as still remaining.
            emails_remaining=len(self.emails) - self.current_index,
            reward=reward,
            cumulative_reward=round(self.cumulative_reward, 4),
            done=False,
        )

    @staticmethod
    def _score_label(predicted, expected):
        """
        Score a predicted label against the expected one.

        Returns a ``(score, feedback)`` pair: 1.0 for an exact match, 0.5 when
        both labels are one of "inbox"/"urgent" (the agent recognised the
        email as legitimate but confused the urgency level), otherwise 0.0.
        """
        if predicted == expected:
            return 1.0, "✓ Correct label"
        legitimate = ("inbox", "urgent")
        if predicted in legitimate and expected in legitimate:
            return 0.5, f"~ Close label (got '{predicted}', expected '{expected}')"
        return 0.0, f"✗ Wrong label (got '{predicted}', expected '{expected}')"

    @staticmethod
    def _score_priority(predicted, expected):
        """
        Score a predicted priority against the expected one.

        Returns a ``(score, feedback)`` pair: 1.0 for equal ranks, 0.5 when
        the ranks differ by one level, otherwise 0.0.

        NOTE: a priority that is not a key of PRIORITY_RANK (including a
        missing one) falls back to rank 1, i.e. it is compared as "medium".
        """
        pred_rank = PRIORITY_RANK.get(predicted, 1)
        true_rank = PRIORITY_RANK.get(expected, 1)
        gap = abs(pred_rank - true_rank)
        if gap == 0:
            return 1.0, "✓ Correct priority"
        if gap == 1:
            return 0.5, f"~ Close priority (got '{predicted}', expected '{expected}')"
        return 0.0, f"✗ Wrong priority (got '{predicted}', expected '{expected}')"

    def _calculate_reward(self, action: "EmailAction", email: dict) -> "EmailReward":
        """
        Score an action against ground truth for the current task.

        Reward structure (max 1.0 per email):
            Task 1: 1.0 for correct spam/not-spam classification
            Task 2: 0.5 label + 0.5 priority, penalty for missed urgents
            Task 3: 0.35 label + 0.35 priority + 0.30 category,
                    penalty for missed urgents and false-positive spam

        The final value is clamped to [0.0, 1.0] and rounded to 4 decimals,
        so penalties can never push a step's reward below zero.
        """
        gt = email["ground_truth"]
        label_score = 0.0
        priority_score = 0.0
        category_score = 0.0
        feedback_parts: List[str] = []
        penalties: List[str] = []
        penalty = 0.0

        # ── TASK 1: Spam Detection ─────────────────────────────────
        if self.task_id == 1:
            # Only the spam / not-spam split matters; any non-spam label
            # counts as "not spam".
            pred_spam = action.label == "spam"
            true_spam = gt["label"] == "spam"
            if pred_spam == true_spam:
                label_score = 1.0
                feedback_parts.append("✓ Correct spam classification")
            elif true_spam:
                feedback_parts.append(
                    f"✗ Missed spam (got '{action.label}', should be 'spam')"
                )
            else:
                feedback_parts.append(
                    "✗ False positive — legitimate email labelled spam (got 'spam', should be 'inbox')"
                )
            total = label_score

        # ── TASK 2: Priority Triage ────────────────────────────────
        elif self.task_id == 2:
            # Label (0.5 weight)
            label_score, note = self._score_label(action.label, gt["label"])
            feedback_parts.append(note)

            # Priority (0.5 weight) — partial credit for adjacent levels
            priority_score, note = self._score_priority(action.priority, gt["priority"])
            feedback_parts.append(note)

            # Critical penalty: urgent email assigned low priority
            if gt["label"] == "urgent" and action.priority == "low":
                penalties.append(
                    "MISSED_URGENT: Urgent email assigned low priority (−0.30)"
                )
                penalty += self._MISSED_URGENT_PENALTY

            total = (label_score * 0.5 + priority_score * 0.5) - penalty

        # ── TASK 3: Full Email Triage ──────────────────────────────
        elif self.task_id == 3:
            # Label (0.35 weight)
            label_score, note = self._score_label(action.label, gt["label"])
            feedback_parts.append(note)

            # Priority (0.35 weight)
            priority_score, note = self._score_priority(action.priority, gt["priority"])
            feedback_parts.append(note)

            # Category (0.30 weight) — exact match only
            if action.category == gt["category"]:
                category_score = 1.0
                feedback_parts.append("✓ Correct category")
            else:
                feedback_parts.append(
                    f"✗ Wrong category (got '{action.category}', expected '{gt['category']}')"
                )

            # Penalty: urgent email not flagged as high priority
            # (stricter than Task 2, which only penalises "low").
            if gt["label"] == "urgent" and action.priority != "high":
                penalties.append(
                    "MISSED_URGENT: Critical email not assigned high priority (−0.30)"
                )
                penalty += self._MISSED_URGENT_PENALTY

            # Penalty: legitimate email falsely marked as spam
            if gt["label"] != "spam" and action.label == "spam":
                penalties.append(
                    "FALSE_SPAM: Legitimate email marked as spam (−0.20)"
                )
                penalty += self._FALSE_SPAM_PENALTY

            total = (
                label_score * 0.35
                + priority_score * 0.35
                + category_score * 0.30
            ) - penalty

        else:
            # Unknown task id: reset() rejects these, so this is defensive.
            total = 0.0

        total = round(max(0.0, min(1.0, total)), 4)
        return EmailReward(
            value=total,
            label_score=label_score,
            priority_score=priority_score,
            category_score=category_score,
            feedback=" | ".join(feedback_parts),
            penalties=penalties,
        )