Spaces:
Sleeping
Sleeping
| """ | |
| Test Suite for OpenEnv Email Triage | |
| Comprehensive tests covering: | |
| - Environment initialization | |
| - API compliance (Pydantic models) | |
| - Email generation dynamics | |
| - Reward computation | |
| - Termination conditions | |
| - Configuration system | |
| """ | |
| import pytest | |
| import os | |
| import sys | |
| # Add project root to path | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from openenv.core.env import OpenEnv | |
| from openenv.core.config import EnvConfig | |
| from openenv.core.models import Observation, Action, Reward, EnvState | |
class TestEnvInitialization:
    """Construction of ``OpenEnv`` and validation of ``EnvConfig``."""

    def test_default_initialization(self):
        """An env built with no arguments exposes an ``EnvConfig``."""
        env = OpenEnv()
        try:
            assert env is not None
            assert isinstance(env.config, EnvConfig)
        finally:
            # Release the env even when an assertion above fails.
            env.close()

    def test_custom_config_initialization(self):
        """Values passed via ``EnvConfig`` are visible on ``env.config``."""
        config = EnvConfig(num_emails=10, verbose=False, random_seed=42)
        env = OpenEnv(config=config)
        try:
            assert env.config.num_emails == 10
            assert env.config.random_seed == 42
        finally:
            env.close()

    def test_invalid_config(self):
        """A negative ``num_emails`` is rejected with ``ValueError``.

        Both the constructor and ``validate()`` sit inside ``pytest.raises``,
        so the test accepts the error being raised by either call.
        """
        with pytest.raises(ValueError):
            config = EnvConfig(num_emails=-5)
            config.validate()
class TestAPICompliance:
    """Return types of ``reset``, ``step`` and ``state``."""

    def test_reset_returns_observation(self):
        """``reset`` returns ``(Observation, dict)`` with a full inbox."""
        env = OpenEnv()
        try:
            obs, info = env.reset()
            assert isinstance(obs, Observation)
            assert isinstance(info, dict)
            assert obs.emails_remaining == env.config.num_emails
        finally:
            # Release the env even when an assertion above fails.
            env.close()

    def test_step_returns_correct_format(self):
        """``step`` returns a 5-tuple ``(obs, reward, terminated, truncated, info)``."""
        env = OpenEnv()
        try:
            env.reset()
            act = Action(action_type=0)
            obs, reward, terminated, truncated, info = env.step(act)
            assert isinstance(obs, Observation)
            assert isinstance(reward, float)
            assert isinstance(terminated, bool)
            assert isinstance(truncated, bool)
            assert isinstance(info, dict)
        finally:
            env.close()

    def test_state_returns_envstate(self):
        """``state`` returns an ``EnvState`` wrapping an observation and reward."""
        env = OpenEnv()
        try:
            env.reset()
            state = env.state()
            assert isinstance(state, EnvState)
            assert isinstance(state.observation, Observation)
            assert isinstance(state.reward, Reward)
        finally:
            env.close()
class TestRewardFunction:
    """Reward values for a correct action and for mishandling urgent email."""

    @staticmethod
    def _expected_action(email):
        """Return the action id this test treats as correct for *email*.

        Rules, checked in order:
          spam                        -> 4
          urgent, body has "forward"  -> 2 (case-insensitive match)
          urgent, otherwise           -> 1
          internal, body has "?"      -> 1
          internal, otherwise         -> 3
          anything else               -> 3

        NOTE(review): this duplicates the environment's grading rules; if
        those rules change, this helper must be updated to match.
        """
        if email.is_spam:
            return 4
        if email.is_urgent:
            return 2 if "forward" in email.body.lower() else 1
        if email.is_internal:
            return 1 if "?" in email.body else 3
        return 3

    def test_correct_action_reward(self):
        """Choosing the expected action for the first email yields +1.0."""
        # Seeded so the sampled email, and therefore the branch of
        # _expected_action being exercised, is reproducible across runs.
        env = OpenEnv(config=EnvConfig(num_emails=1, random_seed=42))
        try:
            obs, _ = env.reset()
            act = self._expected_action(obs.current_email)
            _, reward, _, _, _ = env.step(Action(action_type=act))
            assert reward == pytest.approx(1.0)
        finally:
            # Release the env even when an assertion above fails.
            env.close()

    def test_critical_failure_penalty(self):
        """Taking action 0 on an urgent email yields the -5.0 penalty."""
        env = OpenEnv(
            config=EnvConfig(num_emails=10, urgent_ratio=1.0, spam_ratio=0.0)
        )
        try:
            obs, _ = env.reset()
            # urgent_ratio=1.0 with spam_ratio=0.0 is expected to make the
            # first email urgent; fail loudly if the generator disagrees.
            assert obs.current_email.is_urgent
            # Action 0 (described by the original author as ignore/delete).
            _, reward, _, _, _ = env.step(Action(action_type=0))
            assert reward == pytest.approx(-5.0)
        finally:
            env.close()
class TestTerminationConditions:
    """Episode end and post-termination behaviour of ``step``."""

    def test_episode_termination(self):
        """A 2-email episode ends on its 2nd step; later steps are inert.

        After termination, a further ``step`` must still report
        ``terminated``, give 0.0 reward and show no emails remaining.
        """
        config = EnvConfig(num_emails=2, verbose=False)
        env = OpenEnv(config=config)
        try:
            env.reset()

            _, _, terminated, _, _ = env.step(Action(action_type=0))
            assert not terminated

            _, _, terminated, _, _ = env.step(Action(action_type=0))
            assert terminated

            # One step past the end of the episode.
            obs, reward, terminated, _, _ = env.step(Action(action_type=0))
            assert terminated
            assert reward == 0.0
            assert obs.emails_remaining == 0
        finally:
            # Release the env even when an assertion above fails.
            env.close()
class TestConfiguration:
    """Persistence of ``EnvConfig`` to and from a file."""

    def test_config_save_load(self, tmp_path):
        """Fields survive a ``save`` followed by ``load`` of the same path."""
        config_path = str(tmp_path.joinpath("config.json"))

        original = EnvConfig(num_emails=55, task_level="hard")
        original.save(config_path)

        restored = EnvConfig.load(config_path)
        assert restored.num_emails == 55
        assert restored.task_level == "hard"
| if __name__ == "__main__": | |
| pytest.main([__file__, "-v"]) | |