Spaces:
Sleeping
Sleeping
| """ | |
| OpenEnv - Production-Ready Reinforcement Learning Environment | |
| A Gymnasium-compatible environment implementing the standard step(), reset(), | |
| and state() API for AI agent training in an Email Triage Task. | |
| """ | |
| import numpy as np | |
| from typing import Tuple, Optional, Dict, Any, Union, List | |
| import gymnasium as gym | |
| from gymnasium import spaces | |
| import logging | |
| import time | |
| import random | |
| from openenv.core.config import EnvConfig | |
| from openenv.core.models import Observation, Action, Reward, Email, EnvState | |
def _generate_email(email_id: int, task_level: str, spam_ratio: float, urgent_ratio: float, confounding_ratio: float) -> Email:
    """Build one synthetic Email with ground-truth triage flags.

    All randomness comes from the module-level ``random`` generator, so the
    output is reproducible only after ``random.seed(...)`` has been called.

    Args:
        email_id: Integer used to build the ``"email_<id>"`` identifier.
        task_level: Difficulty name. ``'medium'`` and ``'hard'`` may use
            confounding (misleading) wording; ``'hard'`` additionally perturbs
            about 20% of emails (subject case change plus random-letter noise
            appended to the body).
        spam_ratio: Probability that the email is spam.
        urgent_ratio: Probability that a *non-spam* email is urgent.
        confounding_ratio: Probability of confounding wording (medium/hard only).

    Returns:
        An ``Email`` populated with sender, subject, body and its
        ``is_urgent`` / ``is_spam`` / ``is_internal`` labels.
    """
    spam = random.random() < spam_ratio
    # Spam is never urgent; the urgency draw only happens for non-spam mail.
    urgent = (random.random() < urgent_ratio) if not spam else False
    confusing = (random.random() < confounding_ratio) if task_level in ('medium', 'hard') else False

    internal = False
    if spam:
        sender = f"spammer{random.randint(1,999)}@shady-deals.com"
        if confusing:
            subject = "Invoice #91823 Overdue"
            body = "Please review the attached invoice urgently to avoid account suspension."
        else:
            subject = "You Won $1,000,000!"
            body = "Click here to claim your prize."
    elif urgent:
        sender = "boss@company.com"
        internal = True
        if confusing:
            subject = "Update?"
            body = "Are we on track? Let me know."
        else:
            subject = "URGENT: Project deadline!"
            body = "We need the final report ASAP. Forward it to the team."
    else:
        internal = random.random() < 0.8
        if not internal:
            sender = "newsletter@techweekly.com"
            subject = "This week in Tech"
            body = "Here are the top 10 trends you need to know."
        else:
            sender = f"colleague_{random.randint(1,50)}@company.com"
            if confusing:
                subject = "Git merge conflict"
                body = "I think there is an issue with the latest PR, can you reply with your thoughts?"
            else:
                subject = "Lunch later?"
                body = "I'm heading out at 12."

    # Hard tasks: occasionally flip the subject's case and append letter noise.
    if task_level == 'hard' and random.random() < 0.2:
        subject = subject.upper() if random.random() < 0.5 else subject.lower()
        noise = " ".join(chr(random.randint(97, 122)) for _ in range(20))
        body = body + "\n\n" + noise

    return Email(
        id=f"email_{email_id}",
        sender=sender,
        subject=subject,
        body=body,
        is_urgent=urgent,
        is_spam=spam,
        is_internal=internal,
    )
class OpenEnv(gym.Env):
    """
    Email Triage Environment.

    The agent reads incoming emails one at a time and must choose one action:
        0 = Ignore
        1 = Reply
        2 = Forward
        3 = Archive
        4 = Delete

    Action Space: Discrete(5)

    Observation Space: declared as ``Box(low=0, high=inf, shape=(4,))`` with the
    layout ``[emails_remaining, is_spam, is_urgent, is_internal]``.
    NOTE(review): ``reset()`` / ``step()`` actually return the Pydantic
    ``Observation`` model rather than an array, so the declared space does not
    describe the emitted observations; wrappers that validate observations
    against ``observation_space`` will reject them.

    The ``info`` dicts returned by ``reset()``, ``step()`` and ``state()`` are
    snapshot copies of the internal metrics, so later steps do not mutate an
    ``info`` the caller has already stored.
    """

    metadata = {
        'render_modes': ['human'],
        'render_fps': 1,
    }

    def __init__(
        self,
        config: Optional[EnvConfig] = None,
        render_mode: Optional[str] = None,
    ):
        """Create the environment.

        Args:
            config: Environment configuration; ``EnvConfig()`` is used when
                omitted. NOTE: the object is modified in place (the
                ``render_mode`` override below, and ``random_seed`` written by
                ``seed()``).
            render_mode: Optional override for ``config.render_mode``.
        """
        super().__init__()
        self.config = config if config is not None else EnvConfig()
        # Apply the override first so that the configuration actually used is
        # the one that gets validated.
        if render_mode is not None:
            self.config.render_mode = render_mode
        self.config.validate()
        # Gymnasium wrappers inspect ``env.render_mode``; keep it in sync with
        # the config (render() itself reads ``self.config.render_mode``).
        self.render_mode = self.config.render_mode
        if self.config.random_seed is not None:
            self.seed(self.config.random_seed)
        self._setup_logging()

        # Action space: 5 discrete actions (see class docstring).
        self.action_space = spaces.Discrete(5)
        # Simple array observation space backing the gym interface:
        # [emails_remaining, is_spam, is_urgent, is_internal]
        self.observation_space = spaces.Box(
            low=0.0, high=float('inf'), shape=(4,), dtype=np.float32
        )

        # Episode state (populated by reset()).
        self.emails_queue: List[Email] = []
        self.current_email_index: int = 0
        self.total_reward: float = 0.0
        self.start_time: float = 0.0  # 0.0 means "no episode started yet"
        self.metrics: Dict[str, Any] = {}
        self.logger.info("Email Triage OpenEnv initialized.")

    def _setup_logging(self) -> None:
        """Configure the process-wide ``'OpenEnv'`` logger.

        The level is INFO when ``config.verbose`` is truthy, else WARNING.
        Because the logger is looked up by a fixed name, the level set here
        applies to every instance. A stream handler is attached only if the
        logger has none yet, so repeated construction does not duplicate output.
        """
        self.logger = logging.getLogger('OpenEnv')
        self.logger.setLevel(logging.INFO if self.config.verbose else logging.WARNING)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def _generate_inbox(self) -> None:
        """Replace ``emails_queue`` with ``config.num_emails`` freshly generated emails."""
        self.emails_queue = [
            _generate_email(
                i,
                self.config.task_level,
                self.config.spam_ratio,
                self.config.urgent_ratio,
                self.config.confounding_ratio,
            )
            for i in range(self.config.num_emails)
        ]

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Observation, Dict[str, Any]]:
        """Start a new episode with a freshly generated inbox.

        Args:
            seed: If given, reseeds ``self.np_random`` and the module-level
                ``random`` generator (which drives email generation).
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` where ``info`` is a copy of the freshly
            reset metrics.
        """
        if seed is not None:
            self.np_random, seed = gym.utils.seeding.np_random(seed)
            random.seed(seed)
        self._generate_inbox()
        self.current_email_index = 0
        self.total_reward = 0.0
        self.start_time = time.time()
        self.metrics = {
            'correct_actions': 0,
            'incorrect_actions': 0,
            'critical_failures': 0,
            'steps': 0,
        }
        obs = self.get_observation_model()
        # Return a snapshot so later steps don't mutate the caller's info.
        return obs, dict(self.metrics)

    def _evaluate_action(self, action_type: int, email: Email) -> Tuple[float, str, bool]:
        """
        Score an action against the ground-truth triage rule for ``email``.

        Ground truth (first matching rule wins):
            spam                          -> 4 (Delete)
            urgent, body mentions forward -> 2 (Forward), otherwise 1 (Reply)
            internal, body contains '?'   -> 1 (Reply), otherwise 3 (Archive)
            anything else                 -> 3 (Archive)

        A correct action scores +1 and an incorrect one -1. Two failure cases
        override that score: ignoring or deleting an urgent email scores -5
        (and increments ``metrics['critical_failures']``), and replying to or
        forwarding spam scores -2. The final reward is multiplied by
        ``config.reward_scale``.

        Returns:
            ``(reward, feedback_message, is_correct)``.
        """
        if email.is_spam:
            expected = 4  # Delete
        elif email.is_urgent:
            expected = 2 if "forward" in email.body.lower() else 1  # Forward or Reply
        elif email.is_internal:
            expected = 1 if "?" in email.body else 3  # Reply if question, else Archive
        else:
            expected = 3  # Archive newsletter/generic

        is_correct = action_type == expected
        reward = 1.0 if is_correct else -1.0
        message = "Correctly triaged." if is_correct else f"Incorrect. Ground Truth action was {expected}."

        # Failure overrides (these actions are never the expected one, so
        # is_correct is already False here).
        if email.is_urgent and action_type in [0, 4]:  # Ignoring or deleting urgent mail
            reward = -5.0
            message = "CRITICAL FAILURE: Deleted or ignored urgent email!"
            self.metrics['critical_failures'] += 1
        if email.is_spam and action_type in [1, 2]:  # Replying/Forwarding spam
            reward = -2.0
            message = "FAILURE: Engaged with spam!"
        return reward * self.config.reward_scale, message, is_correct

    def step(self, action: Union[Action, int]) -> Tuple[Observation, float, bool, bool, Dict[str, Any]]:
        """Apply ``action`` to the current email and advance to the next one.

        Args:
            action: Either an ``Action`` model (its ``action_type`` is used) or
                anything convertible with ``int()``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. Once the
            inbox is exhausted, further calls return reward 0.0 and
            ``terminated=True`` without changing state. ``info`` is a snapshot
            copy of the metrics.
        """
        act_val = action.action_type if isinstance(action, Action) else int(action)

        if self.current_email_index >= len(self.emails_queue):
            return self.get_observation_model(), 0.0, True, False, dict(self.metrics)

        current_email = self.emails_queue[self.current_email_index]
        self.metrics['steps'] += 1

        step_rew, msg, is_correct = self._evaluate_action(act_val, current_email)
        self.total_reward += step_rew
        if is_correct:
            self.metrics['correct_actions'] += 1
        else:
            self.metrics['incorrect_actions'] += 1
        self.metrics['last_reward_msg'] = msg
        self.metrics['last_reward'] = step_rew

        self.current_email_index += 1
        terminated = self.current_email_index >= len(self.emails_queue)
        truncated = False
        obs = self.get_observation_model()
        return obs, float(step_rew), terminated, truncated, dict(self.metrics)

    def get_observation_model(self) -> Observation:
        """Build the ``Observation`` for the current position in the inbox.

        ``current_email`` is ``None`` once the inbox is exhausted.
        ``time_elapsed`` is wall-clock seconds since the last ``reset()``, or
        0.0 if no episode has been started yet.
        """
        remaining = len(self.emails_queue) - self.current_email_index
        current_email = self.emails_queue[self.current_email_index] if remaining > 0 else None
        # Before the first reset() start_time is 0.0; subtracting it would
        # report the full Unix epoch as elapsed time.
        elapsed = time.time() - self.start_time if self.start_time > 0.0 else 0.0
        return Observation(
            emails_remaining=remaining,
            current_email=current_email,
            time_elapsed=elapsed,
        )

    def state(self) -> EnvState:
        """Returns the full strictly-typed Pydantic EnvState (info is a metrics snapshot)."""
        obs = self.get_observation_model()
        rew = Reward(
            step_reward=self.metrics.get('last_reward', 0.0),
            total_reward=self.total_reward,
            message=self.metrics.get('last_reward_msg', ""),
        )
        term = self.current_email_index >= len(self.emails_queue)
        return EnvState(
            observation=obs,
            reward=rew,
            terminated=term,
            truncated=False,
            info=dict(self.metrics),
        )

    def render(self) -> None:
        """Print the current email and running reward when ``render_mode == 'human'``."""
        if self.config.render_mode != 'human':
            return
        obs = self.get_observation_model()
        print(f"\n[{obs.emails_remaining} Emails Remaining] Total Reward: {self.total_reward:.1f}")
        if obs.current_email:
            print("="*40)
            print(f"From: {obs.current_email.sender}")
            print(f"Subject: {obs.current_email.subject}")
            print("-" * 40)
            print(f"{obs.current_email.body}")
            print("="*40)
            print("Actions: 0=Ignore, 1=Reply, 2=Forward, 3=Archive, 4=Delete")

    def close(self) -> None:
        """No resources to release."""
        pass

    def seed(self, seed: Optional[int] = None) -> int:
        """Seed ``self.np_random`` and the module-level ``random`` generator.

        If ``seed`` is None, one is derived from the current wall-clock time.
        The resulting seed is stored in ``config.random_seed`` and returned.
        """
        if seed is None:
            seed = int(time.time() * 1000) % 2**31
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        random.seed(seed)
        self.config.random_seed = seed
        return seed