"""
Test Suite for OpenEnv Email Triage
Comprehensive tests covering:
- Environment initialization
- API compliance (Pydantic models)
- Email generation dynamics
- Reward computation
- Termination conditions
- Configuration system
"""
import pytest
import os
import sys
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from openenv.core.env import OpenEnv
from openenv.core.config import EnvConfig
from openenv.core.models import Observation, Action, Reward, EnvState
class TestEnvInitialization:
    """Construction of OpenEnv with default, custom, and invalid configs."""

    def test_default_initialization(self):
        """OpenEnv() with no arguments exposes an EnvConfig on ``env.config``."""
        env = OpenEnv()
        try:
            assert env is not None
            assert isinstance(env.config, EnvConfig)
        finally:
            # Release the environment even when an assertion above fails.
            env.close()

    def test_custom_config_initialization(self):
        """Values given to EnvConfig are readable back from ``env.config``."""
        config = EnvConfig(num_emails=10, verbose=False, random_seed=42)
        env = OpenEnv(config=config)
        try:
            assert env.config.num_emails == 10
            assert env.config.random_seed == 42
        finally:
            env.close()

    def test_invalid_config(self):
        """A negative ``num_emails`` must surface as ValueError."""
        # Both construction and validate() stay inside the context manager:
        # which of the two raises is not visible from this file.
        with pytest.raises(ValueError):
            config = EnvConfig(num_emails=-5)
            config.validate()
class TestAPICompliance:
    """Return types and shapes of the reset / step / state API."""

    def test_reset_returns_observation(self):
        """reset() yields an (Observation, dict) pair with a full inbox."""
        env = OpenEnv()
        try:
            obs, info = env.reset()
            assert isinstance(obs, Observation)
            assert isinstance(info, dict)
            # Right after a reset no email has been processed yet.
            assert obs.emails_remaining == env.config.num_emails
        finally:
            # Release the environment even when an assertion above fails.
            env.close()

    def test_step_returns_correct_format(self):
        """step() yields the 5-tuple (obs, reward, terminated, truncated, info)."""
        env = OpenEnv()
        try:
            env.reset()
            act = Action(action_type=0)
            obs, reward, terminated, truncated, info = env.step(act)
            assert isinstance(obs, Observation)
            assert isinstance(reward, float)
            assert isinstance(terminated, bool)
            assert isinstance(truncated, bool)
            assert isinstance(info, dict)
        finally:
            env.close()

    def test_state_returns_envstate(self):
        """state() yields an EnvState carrying an Observation and a Reward."""
        env = OpenEnv()
        try:
            env.reset()
            state = env.state()
            assert isinstance(state, EnvState)
            assert isinstance(state.observation, Observation)
            assert isinstance(state.reward, Reward)
        finally:
            env.close()
class TestRewardFunction:
    """Reward values for a correct action and for a critical mistake."""

    def test_correct_action_reward(self):
        """Taking the expected action on the first email is rewarded with 1.0."""
        env = OpenEnv(config=EnvConfig(num_emails=1))
        try:
            obs, _ = env.reset()
            email = obs.current_email
            # Determine ground truth locally to force correct action.
            # NOTE(review): this duplicates the labelling rules the
            # environment is presumed to use; keep both in sync.
            if email.is_spam:
                act = 4
            elif email.is_urgent:
                act = 2 if "forward" in email.body.lower() else 1
            elif email.is_internal:
                act = 1 if "?" in email.body else 3
            else:
                act = 3
            _, reward, _, _, _ = env.step(Action(action_type=act))
            assert reward == 1.0
        finally:
            # Release the environment even when an assertion above fails.
            env.close()

    def test_critical_failure_penalty(self):
        """Answering an urgent email with action 0 is penalised with -5.0."""
        # urgent_ratio=1.0 with spam_ratio=0.0 is meant to make every
        # generated email urgent; the assertion below checks the first one.
        env = OpenEnv(
            config=EnvConfig(num_emails=10, urgent_ratio=1.0, spam_ratio=0.0)
        )
        try:
            obs, _ = env.reset()
            # It's an urgent email
            assert obs.current_email.is_urgent
            # Ignored or deleted
            _, reward, _, _, _ = env.step(Action(action_type=0))
            assert reward == -5.0
        finally:
            env.close()
class TestTerminationConditions:
    """Episode termination once every email has been handled."""

    def test_episode_termination(self):
        """A 2-email episode terminates on the second step and stays terminated."""
        config = EnvConfig(num_emails=2, verbose=False)
        env = OpenEnv(config=config)
        try:
            env.reset()

            _, _, term1, _, _ = env.step(Action(action_type=0))
            assert not term1

            _, _, term2, _, _ = env.step(Action(action_type=0))
            assert term2

            # Stepping past the end must be harmless: still terminated,
            # zero reward, and nothing left in the inbox.
            obs, rem_rew, term3, _, _ = env.step(Action(action_type=0))
            assert term3
            assert rem_rew == 0.0
            assert obs.emails_remaining == 0
        finally:
            # Release the environment even when an assertion above fails.
            env.close()
class TestConfiguration:
    """Persistence round-trip for EnvConfig."""

    def test_config_save_load(self, tmp_path):
        """A config written with save() comes back unchanged from load()."""
        target = str(tmp_path / "config.json")
        EnvConfig(num_emails=55, task_level="hard").save(target)

        restored = EnvConfig.load(target)

        assert restored.num_emails == 55
        assert restored.task_level == "hard"
if __name__ == "__main__":
    # Hand pytest's exit code to the interpreter so that running this file
    # directly reports failure (non-zero status) when any test fails.
    sys.exit(pytest.main([__file__, "-v"]))