Spaces:
Sleeping
Sleeping
# --- standard library ---
import random
from typing import Any, Dict, List, Optional, Tuple

# --- third party ---
import nltk
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize

# --- project ---
from models import Action, ActionType, Email, Observation, Reward
from .data import TASK_CONFIGS

# Download the NLTK data packages (quietly) at import time, before the
# analyzer below is constructed.
nltk.download("vader_lexicon", quiet=True)
nltk.download("punkt_tab", quiet=True)

# Module-level singletons shared by every environment instance.
stemmer = PorterStemmer()
vader = SentimentIntensityAnalyzer()
class EmailSortingEnvironment:
    """Step-based environment in which an agent works through a queue of emails.

    Three tasks are dispatched on ``task_id`` (chosen in :meth:`reset`):
    ``email_classification``, ``response_drafting`` and ``support_session``.
    Every :meth:`step` scores one action and returns a per-step ``Reward``.
    When the queue is exhausted or ``max_steps`` is reached, a final grader
    score clamped to ``[0.001, 0.999]`` is stored in ``last_grader_score`` and
    reported as ``info["final_score"]``.
    """

    def __init__(self) -> None:
        """Create an uninitialised environment; call :meth:`reset` before stepping."""
        self.task_id: str = ""
        self.task_config: Dict = {}
        # Working copy of the task's emails for the current episode.
        self.email_queue: List[Dict] = []
        # Index of the next email to handle; everything before it is processed.
        self.current_email_idx: int = 0
        self.step_count: int = 0
        self.processed_emails: List[Dict] = []
        # One record per scored action; the final graders read these.
        self.episode_actions: List[Dict] = []
        self.done: bool = False
        self.last_grader_score: Optional[float] = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _credit(
        components: Dict[str, float], key: str, delta: float, total: float
    ) -> float:
        """Record ``delta`` under ``key`` in ``components`` and return ``total + delta``."""
        components[key] = delta
        return total + delta

    @staticmethod
    def _to_email(raw: Dict) -> Email:
        """Build the ``Email`` exposed to the agent from a raw config dict.

        Only the fields listed here are copied; the ``correct_*`` answer keys
        stored in the raw dict are not passed on.
        """
        return Email(
            id=raw["id"],
            subject=raw["subject"],
            body=raw["body"],
            sender=raw["sender"],
            sender_tier=raw["sender_tier"],
            received_minutes_ago=raw["received_minutes_ago"],
        )

    # -------------------------------------------------------------- inspection

    def state(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of the episode's progress.

        Note that ``episode_actions`` is the live internal list, not a copy.
        """
        max_steps = self.task_config.get("max_steps", 0) if self.task_config else 0
        return {
            "task_id": self.task_id,
            "step_count": self.step_count,
            "processed_count": len(self.processed_emails),
            "queue_size": max(0, len(self.email_queue) - self.current_email_idx),
            "done": self.done,
            "max_steps": max_steps,
            "episode_actions": self.episode_actions,
            "last_grader_score": self.last_grader_score,
        }

    def get_observation(self) -> Observation:
        """Build the observation for the current position in the queue.

        Returns a default ``Observation()`` when no task has been selected.
        For ``support_session`` the whole remaining queue is exposed (the agent
        may pick any email); for the other tasks only the current email is.
        """
        if not self.task_id:
            return Observation()

        task = self.task_config
        if self.task_id == "email_classification":
            available_actions = ["classify"]
        elif self.task_id == "response_drafting":
            available_actions = ["respond"]
        else:
            available_actions = ["classify", "respond", "escalate", "archive", "skip"]

        remaining = self.email_queue[self.current_email_idx:]
        if self.task_id == "support_session":
            queue = [self._to_email(raw) for raw in remaining]
            current = queue[0] if queue else None
        else:
            queue = []
            current = self._to_email(remaining[0]) if remaining else None

        max_steps = task.get("max_steps", 0)
        return Observation(
            current_email=current,
            email_queue=queue,
            processed_count=len(self.processed_emails),
            step_count=self.step_count,
            task_id=self.task_id,
            task_description=task.get("description", ""),
            available_actions=available_actions,
            context={
                "max_steps": max_steps,
                "remaining_steps": max_steps - self.step_count,
                "queue_size": len(remaining),
            },
        )

    # ------------------------------------------------------- per-step handlers

    def process_email_classification(self, action: Action) -> Tuple[Reward, Dict]:
        """Score a ``classify`` action against the current email.

        Category: +0.15 if correct, -0.05 otherwise. Urgency: +0.05 if correct,
        -0.02 otherwise. A non-classify action costs -0.05 and does not consume
        the email. An empty queue yields a zero reward with
        ``{"error": "queue_empty"}``.
        """
        if self.current_email_idx >= len(self.email_queue):
            return Reward(value=0.0, reason="Queue is empty"), {"error": "queue_empty"}
        if action.action_type != ActionType.CLASSIFY:
            return (
                Reward(
                    value=-0.05,
                    components={"wrong_action": -0.05},
                    reason="Must use classify action",
                ),
                {},
            )

        email = self.email_queue[self.current_email_idx]
        cat_given = action.category.value if action.category else None
        urg_given = action.urgency.value if action.urgency else None
        category_hit = cat_given == email.get("correct_category")
        urgency_hit = urg_given == email.get("correct_urgency")

        components: Dict[str, float] = {}
        total = 0.0
        if category_hit:
            total = self._credit(components, "category_correct", 0.15, total)
        else:
            total = self._credit(components, "category_wrong", -0.05, total)
        if urgency_hit:
            total = self._credit(components, "urgency_correct", 0.05, total)
        else:
            total = self._credit(components, "urgency_wrong", -0.02, total)

        self.episode_actions.append(
            {
                "email_id": email["id"],
                "action_type": "classify",
                "category": cat_given,
                "urgency": urg_given,
                "correct_category": email.get("correct_category"),
                "correct_urgency": email.get("correct_urgency"),
            }
        )
        self.processed_emails.append(email)
        self.current_email_idx += 1

        cat_ok = "correct" if category_hit else "wrong"
        urg_ok = "correct" if urgency_hit else "wrong"
        return (
            Reward(
                value=round(total, 4),
                components=components,
                reason=f"{email['id']}: category={cat_ok}, urgency={urg_ok}",
            ),
            {"email_id": email["id"]},
        )

    def process_response_drafting(self, action: Action) -> Tuple[Reward, Dict]:
        """Score a ``respond`` action against the current email.

        Components: length (+0.05 for at least 50 stripped characters, -0.1
        otherwise), stemmed keyword coverage (up to +0.25), and a -0.1 penalty
        when VADER's ``neg`` score exceeds 0.4. A non-respond action costs
        -0.05 and does not consume the email.
        """
        if self.current_email_idx >= len(self.email_queue):
            return Reward(value=0.0, reason="Queue empty"), {"error": "queue_empty"}
        if action.action_type != ActionType.RESPOND:
            return (
                Reward(
                    value=-0.05,
                    components={"wrong_action": -0.05},
                    reason="Must use respond action",
                ),
                {},
            )

        email = self.email_queue[self.current_email_idx]
        response = (action.response_text or "").strip()
        components: Dict[str, float] = {}
        total = 0.0

        # NOTE(review): the step reward treats exactly 50 characters as
        # adequate, while response_drafting_score requires more than 50.
        if len(response) < 50:
            total = self._credit(components, "too_short", -0.1, total)
        else:
            total = self._credit(components, "adequate_length", 0.05, total)

        # Compare stems so inflected forms ("refunded" vs "refund") still match.
        required = email.get("required_keywords", [])
        min_matches = email.get("min_keyword_matches", 1)
        response_stems = {stemmer.stem(tok) for tok in word_tokenize(response.lower())}
        matched = [kw for kw in required if stemmer.stem(kw.lower()) in response_stems]
        kw_score = round(min(len(matched) / max(min_matches, 1), 1.0) * 0.25, 4)
        # Recorded in the breakdown so the components add up to the reward
        # value, as they do in the other handlers.
        total = self._credit(components, "keyword_match", kw_score, total)

        if vader.polarity_scores(response)["neg"] > 0.4:
            total = self._credit(components, "unprofessional", -0.1, total)

        self.episode_actions.append(
            {
                "email_id": email["id"],
                "action_type": "respond",
                "response_length": len(response),
                "keywords_matched": matched,
                "keywords_required": required,
            }
        )
        self.processed_emails.append(email)
        self.current_email_idx += 1

        return (
            Reward(
                value=round(total, 4),
                components=components,
                reason=f"{email['id']}: {len(matched)}/{min_matches} keywords matched",
            ),
            {"email_id": email["id"], "keywords_matched": matched},
        )

    def process_support_session(self, action: Action) -> Tuple[Reward, Dict]:
        """Score one action of the mixed support-session task.

        If ``action.email_id`` names an email still in the queue, that email is
        swapped to the head and handled now; an unknown or missing id falls
        back to the current head. The reward combines handling order (VIP /
        high / low urgency vs. position), classification, the chosen action
        type, presence of a response, and a penalty for not archiving spam.
        """
        remaining = self.email_queue[self.current_email_idx:]
        if not remaining:
            return Reward(value=0.0, reason="Queue empty"), {"error": "queue_empty"}

        head = self.current_email_idx
        target_idx = head
        if action.email_id:
            for offset, candidate in enumerate(remaining):
                if candidate["id"] == action.email_id:
                    target_idx = head + offset
                    break
        if target_idx != head:
            # Swap (not rotate): the displaced head takes the chosen email's slot.
            self.email_queue[head], self.email_queue[target_idx] = (
                self.email_queue[target_idx],
                self.email_queue[head],
            )

        email = self.email_queue[head]
        # Zero-based count of emails already handled in this episode.
        position = len(self.processed_emails)
        is_vip = email.get("sender_tier") == "vip"
        expected_urgency = email.get("correct_urgency", "low")
        components: Dict[str, float] = {}
        total = 0.0

        # Ordering: VIP status takes precedence over urgency.
        if is_vip:
            if position < 4:
                total = self._credit(components, "vip_priority", 0.08, total)
            else:
                total = self._credit(components, "vip_delayed", -0.05, total)
        elif expected_urgency == "high" and position < 6:
            total = self._credit(components, "high_priority", 0.05, total)
        elif expected_urgency == "low" and position > 6:
            total = self._credit(components, "low_priority", 0.03, total)

        cat_given = action.category.value if action.category else None
        urg_given = action.urgency.value if action.urgency else None
        if cat_given == email.get("correct_category"):
            total = self._credit(components, "category_correct", 0.04, total)
        if urg_given == email.get("correct_urgency"):
            total = self._credit(components, "urgency_correct", 0.02, total)

        action_type = action.action_type.value
        correct_action = email.get("correct_action", "respond")
        # NOTE(review): escalating an ``escalation_required`` email is credited
        # here, but support_session_score only counts exact ``correct_action``
        # matches.
        action_is_right = action_type == correct_action or (
            email.get("escalation_required") and action_type == "escalate"
        )
        if action_is_right:
            total = self._credit(components, "correct_action", 0.06, total)
        elif action_type in ("respond", "escalate", "archive"):
            total = self._credit(components, "wrong_action", -0.03, total)

        if (
            action_type == "respond"
            and action.response_text
            and len(action.response_text) > 50
        ):
            total = self._credit(components, "response_present", 0.02, total)
        if email.get("correct_category") == "spam" and action_type != "archive":
            total = self._credit(components, "spam_not_archived", -0.04, total)

        self.episode_actions.append(
            {
                "email_id": email["id"],
                "action_type": action_type,
                "category": cat_given,
                "urgency": urg_given,
                "correct_category": email.get("correct_category"),
                "correct_urgency": email.get("correct_urgency"),
                "correct_action": correct_action,
                "position": position,
            }
        )
        self.processed_emails.append(email)
        self.current_email_idx += 1

        return (
            Reward(
                value=round(total, 4),
                components=components,
                reason=f"{email['id']} at position {position}: action={action_type}",
            ),
            {"email_id": email["id"]},
        )

    def process_action(self, action: Action) -> Tuple[Reward, Dict]:
        """Route ``action`` to the handler for the active task."""
        if self.task_id == "email_classification":
            return self.process_email_classification(action)
        if self.task_id == "response_drafting":
            return self.process_response_drafting(action)
        if self.task_id == "support_session":
            return self.process_support_session(action)
        return Reward(value=0.0, reason="Unknown task"), {}

    # ------------------------------------------------------------ final graders

    def email_classification_score(self) -> float:
        """Return 0.7 * category accuracy + 0.3 * urgency accuracy.

        Accuracy is measured against the number of configured emails, so
        unprocessed emails count as misses.
        """
        # max(..., 1) avoids ZeroDivisionError on an empty task configuration.
        total = max(len(TASK_CONFIGS["email_classification"]["emails"]), 1)
        cat_correct = sum(
            1
            for rec in self.episode_actions
            if rec.get("category") == rec.get("correct_category")
        )
        urg_correct = sum(
            1
            for rec in self.episode_actions
            if rec.get("urgency") == rec.get("correct_urgency")
        )
        return round(0.7 * cat_correct / total + 0.3 * urg_correct / total, 3)

    def response_drafting_score(self) -> float:
        """Return the mean per-email drafting score over all configured emails.

        Each recorded response contributes 0.8 * keyword coverage plus a length
        bonus; actions whose email id is not in the task configuration are
        ignored.
        """
        emails = TASK_CONFIGS["response_drafting"]["emails"]
        email_map = {e["id"]: e for e in emails}
        score = 0.0
        for rec in self.episode_actions:
            cfg = email_map.get(rec.get("email_id"))
            if not cfg:
                continue
            matched = rec.get("keywords_matched", [])
            # Same default as process_response_drafting, so a config without
            # this key no longer raises KeyError at episode end.
            min_m = cfg.get("min_keyword_matches", 1)
            length = rec.get("response_length", 0)
            kw = min(len(matched) / max(min_m, 1), 1.0)
            # NOTE(review): for length > 50, length / 200 already exceeds 0.2,
            # so this bonus is effectively a flat 0.2 — confirm that is intended.
            length_bonus = min(length / 200, 0.2) if length > 50 else 0.0
            score += kw * 0.8 + length_bonus
        return round(score / max(len(emails), 1), 3)

    def support_session_score(self) -> float:
        """Return the support-session grade, capped at 1.0.

        Weights: up to 0.20 for VIP ordering, 0.10 for high-urgency (non-VIP)
        ordering, 0.15 each for category and urgency accuracy, 0.30 for action
        accuracy and 0.10 for queue coverage. Late-handled VIP / high-urgency
        emails earn 40% of their ordering weight.
        """
        emails = TASK_CONFIGS["support_session"]["emails"]
        emails_by_id = {e["id"]: e for e in emails}
        vip_ids = {e["id"] for e in emails if e.get("sender_tier") == "vip"}
        high_ids = {
            e["id"]
            for e in emails
            if e.get("correct_urgency") == "high" and e.get("sender_tier") != "vip"
        }

        # First handling position of every email id (same result as
        # list.index, without rescanning the list for every lookup).
        first_pos: Dict[Any, int] = {}
        for idx, rec in enumerate(self.episode_actions):
            first_pos.setdefault(rec["email_id"], idx)

        vip_weight = 0.20 / max(len(vip_ids), 1)
        high_weight = 0.10 / max(len(high_ids), 1)
        priority = 0.0
        for eid in vip_ids:
            pos = first_pos.get(eid)
            if pos is not None:
                priority += vip_weight if pos < 4 else vip_weight * 0.4
        for eid in high_ids:
            pos = first_pos.get(eid)
            if pos is not None:
                priority += high_weight if pos < 6 else high_weight * 0.4

        n = len(self.episode_actions)

        def count_matches(given_key: str, expected_key: str) -> int:
            """Count actions whose ``given_key`` equals the configured answer."""
            return sum(
                1
                for rec in self.episode_actions
                if rec.get(given_key)
                == emails_by_id.get(rec["email_id"], {}).get(expected_key)
            )

        cat_ok = count_matches("category", "correct_category")
        urg_ok = count_matches("urgency", "correct_urgency")
        act_ok = count_matches("action_type", "correct_action")

        classification = cat_ok / max(n, 1) * 0.15 + urg_ok / max(n, 1) * 0.15
        action_score = act_ok / max(n, 1) * 0.30
        # max(..., 1) avoids ZeroDivisionError on an empty task configuration.
        coverage = (n / max(len(emails), 1)) * 0.10
        return round(min(priority + classification + action_score + coverage, 1.0), 3)

    def compute_final_score(self) -> float:
        """Return the task's grader score clamped to ``[0.001, 0.999]``.

        Returns 0.001 when no action was recorded or the task id is unknown.
        """
        if not self.episode_actions:
            return 0.001
        if self.task_id == "email_classification":
            score = self.email_classification_score()
        elif self.task_id == "response_drafting":
            score = self.response_drafting_score()
        elif self.task_id == "support_session":
            score = self.support_session_score()
        else:
            return 0.001
        return round(max(0.001, min(0.999, score)), 3)

    # ---------------------------------------------------------- episode control

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict]:
        """Apply one action and return ``(observation, reward, done, info)``.

        The episode ends when the queue is exhausted or ``max_steps`` is
        reached; the final grader score is then stored and added to ``info``
        as ``final_score``. Every non-terminal step carries an extra -0.005
        step penalty. Stepping a finished episode returns a zero reward.

        Raises:
            RuntimeError: if :meth:`reset` has not been called yet.
        """
        if self.done:
            return (
                self.get_observation(),
                Reward(value=0.0, reason="Episode already done"),
                True,
                {},
            )
        if not self.task_id:
            raise RuntimeError("Environment not initialized, call reset() first.")

        self.step_count += 1
        reward, info = self.process_action(action)

        queue_exhausted = self.current_email_idx >= len(self.email_queue)
        out_of_steps = self.step_count >= self.task_config["max_steps"]
        if queue_exhausted or out_of_steps:
            self.done = True
            self.last_grader_score = self.compute_final_score()
            info["final_score"] = self.last_grader_score
        else:
            # Small per-step cost, applied only while the episode continues.
            reward.value = round(reward.value - 0.005, 4)
            reward.components["step_penalty"] = -0.005

        return self.get_observation(), reward, self.done, info

    def reset(
        self, task_id: str = "email_classification", seed: Optional[int] = None
    ) -> Observation:
        """Start a new episode for ``task_id`` and return the first observation.

        The task's emails are shallow-copied and shuffled into a fresh queue.

        Args:
            task_id: key into ``TASK_CONFIGS``.
            seed: optional seed for a reproducible queue order. When ``None``
                (the default) the module-level ``random`` state is used, as
                before.

        Raises:
            ValueError: if ``task_id`` is not a key of ``TASK_CONFIGS``.
        """
        if task_id not in TASK_CONFIGS:
            raise ValueError(
                f"Unknown task: {task_id}. Valid: {list(TASK_CONFIGS.keys())}"
            )
        self.task_id = task_id
        self.task_config = TASK_CONFIGS[task_id]

        rng = random if seed is None else random.Random(seed)
        emails = self.task_config["emails"]
        # sample(k=len) yields a shuffled copy without mutating the config list;
        # dict(e) keeps per-episode edits away from the shared config dicts.
        self.email_queue = [dict(e) for e in rng.sample(emails, len(emails))]

        self.current_email_idx = 0
        self.step_count = 0
        self.processed_emails = []
        self.episode_actions = []
        self.done = False
        self.last_grader_score = None
        return self.get_observation()