Spaces:
Sleeping
Sleeping
| from typing import Dict, Any, Tuple | |
| from .models import State, Action, Observation | |
| from .tasks import get_initial_state | |
| from .graders import grade_action | |
class EmailTriageEnv:
    """Episode-based environment in which an agent triages an inbox of emails.

    Each successful step removes one email from the observation's inbox,
    scores the action with ``grade_action`` and files the email in the
    observation list that matches the action type.  An episode is done once
    the inbox is empty or ``step_count`` reaches ``max_steps``.
    """

    # Action type -> name of the observation list that receives the email.
    _FOLDER_BY_ACTION = {
        "reply": "replied",
        "forward": "forwarded",
        "archive": "archived",
        "mark_spam": "spam",
        "request_info": "pending_info",
        "escalate": "escalated",
    }

    def __init__(self):
        """Create an environment with no active episode."""
        # Stays None until reset() or state() builds an episode.
        self._state: State = None
        self.current_task: str = "easy"

    async def reset(self, task: str = "easy") -> "Tuple[Observation, Dict[str, Any]]":
        """Start a new episode for ``task``.

        Args:
            task: Task name handed to ``get_initial_state``; it is also
                remembered and passed to ``grade_action`` on later steps.

        Returns:
            The initial observation and an empty info dict.
        """
        self.current_task = task
        self._state = get_initial_state(task)
        return self._state.observation, {}

    async def step(self, action_dict: dict) -> "Tuple[Observation, float, bool, Dict[str, Any]]":
        """Apply one action to the episode and return the outcome.

        Args:
            action_dict: Keyword arguments used to build an ``Action``. The
                resulting action is read for ``email_id`` and ``action_type``.

        Returns:
            ``(observation, reward, done, info)``.  ``reward`` is clamped to
            ``[0.0, 1.0]``.  Failures never raise; they return a reward of
            ``0.0`` and describe the problem in ``info["error"]``:

            * no active episode (never reset, or already done): ``done`` is
              True and the observation is None if there is no state at all;
            * ``action_dict`` cannot be turned into an ``Action``: the
              attempt still consumes a step;
            * ``email_id`` matches nothing in the inbox: the attempt still
              consumes a step.
        """
        state = self._state
        if state is None or state.is_done:
            observation = state.observation if state is not None else None
            return observation, 0.0, True, {"error": "Environment must be reset before stepping"}

        # From here on every attempt costs one step, whether or not it succeeds.
        state.step_count += 1

        try:
            action = Action(**action_dict)
        except Exception as exc:
            # Broad on purpose: which exception types Action raises for bad
            # input is not visible from this module.
            if state.step_count >= state.max_steps:
                state.is_done = True
            return state.observation, 0.0, state.is_done, {"error": f"Invalid action format: {exc}"}

        inbox = state.observation.inbox
        position = next(
            (index for index, email in enumerate(inbox) if email.id == action.email_id),
            None,
        )
        if position is None:
            state.is_done = self._is_finished(state)
            return state.observation, 0.0, state.is_done, {"error": "Email ID not found in inbox"}

        # The email leaves the inbox before grading, so the grader sees the
        # state without it.
        email = inbox.pop(position)

        reward = grade_action(self.current_task, action, email, state)
        reward = max(0.0, min(1.0, reward))
        state.score = max(0.0, min(1.0, state.score + reward))

        folder = self._FOLDER_BY_ACTION.get(action.action_type)
        if folder is not None:
            getattr(state.observation, folder).append(email)
        # NOTE(review): an action type outside the table leaves the email in
        # no list at all (it is already out of the inbox).  Presumably Action
        # restricts action_type to the six values above - confirm.

        if self._is_finished(state):
            state.is_done = True
        return state.observation, reward, state.is_done, {}

    def state(self) -> "State":
        """Return the current state, creating an "easy" episode if none exists."""
        if self._state is None:
            self._state = get_initial_state("easy")
        return self._state

    @staticmethod
    def _is_finished(state) -> bool:
        """Return True when the inbox is empty or the step budget is spent."""
        return not state.observation.inbox or state.step_count >= state.max_steps