File size: 4,249 Bytes
3eb9552
4b77608
3eb9552
 
 
4b77608
 
3eb9552
 
 
 
 
 
 
 
 
 
 
 
4b77608
 
 
3eb9552
 
 
 
 
 
 
4b77608
3eb9552
4b77608
3eb9552
4b77608
3eb9552
 
4b77608
3eb9552
 
4b77608
 
 
3eb9552
4b77608
3eb9552
 
 
4b77608
3eb9552
4b77608
3eb9552
4b77608
3eb9552
 
4b77608
3eb9552
4b77608
 
3eb9552
4b77608
3eb9552
 
 
 
 
 
4b77608
3eb9552
 
4b77608
 
 
 
3eb9552
 
 
4b77608
 
 
 
 
 
 
 
 
 
 
 
3eb9552
4b77608
 
3eb9552
 
4b77608
 
 
 
 
3eb9552
4b77608
 
 
3eb9552
 
 
4b77608
 
3eb9552
 
 
4b77608
 
3eb9552
4b77608
 
3eb9552
4b77608
 
 
 
 
3eb9552
 
 
 
4b77608
3eb9552
 
 
4b77608
 
 
3eb9552
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Test Suite for OpenEnv Email Triage

Comprehensive tests covering:
- Environment initialization
- API compliance (Pydantic models)
- Email generation dynamics
- Reward computation
- Termination conditions
- Configuration system
"""

import pytest
import os
import sys

# Prepend the parent of this file's directory (presumably the project root)
# to sys.path. This must run before the `openenv` imports below, which
# presumably rely on it when the package is not installed -- TODO confirm.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from openenv.core.env import OpenEnv
from openenv.core.config import EnvConfig
from openenv.core.models import Observation, Action, Reward, EnvState

class TestEnvInitialization:
    """Construction of ``OpenEnv`` with default, custom and invalid configs."""

    def test_default_initialization(self):
        """An env built without arguments exposes an ``EnvConfig`` instance."""
        env = OpenEnv()
        # try/finally guarantees close() runs even when an assertion fails,
        # so a failing test does not leak the environment.
        try:
            assert env is not None
            assert isinstance(env.config, EnvConfig)
        finally:
            env.close()

    def test_custom_config_initialization(self):
        """Values passed through ``EnvConfig`` are visible on ``env.config``."""
        config = EnvConfig(num_emails=10, verbose=False, random_seed=42)
        env = OpenEnv(config=config)
        try:
            assert env.config.num_emails == 10
            assert env.config.random_seed == 42
        finally:
            env.close()

    def test_invalid_config(self):
        """A negative ``num_emails`` is rejected with ``ValueError``."""
        # Construction stays inside the raises-block on purpose: the test
        # passes whether the ValueError comes from the constructor or from
        # the explicit validate() call.
        with pytest.raises(ValueError):
            config = EnvConfig(num_emails=-5)
            config.validate()
            
class TestAPICompliance:
    """Type checks on the values returned by ``reset``, ``step`` and ``state``."""

    def test_reset_returns_observation(self):
        """``reset()`` returns ``(Observation, dict)`` with all emails remaining."""
        env = OpenEnv()
        # try/finally guarantees close() runs even when an assertion fails.
        try:
            obs, info = env.reset()

            assert isinstance(obs, Observation)
            assert isinstance(info, dict)
            assert obs.emails_remaining == env.config.num_emails
        finally:
            env.close()

    def test_step_returns_correct_format(self):
        """``step()`` returns a 5-tuple: obs, reward, terminated, truncated, info."""
        env = OpenEnv()
        try:
            env.reset()

            act = Action(action_type=0)
            obs, reward, terminated, truncated, info = env.step(act)

            assert isinstance(obs, Observation)
            assert isinstance(reward, float)
            assert isinstance(terminated, bool)
            assert isinstance(truncated, bool)
            assert isinstance(info, dict)
        finally:
            env.close()

    def test_state_returns_envstate(self):
        """``state()`` returns an ``EnvState`` holding an observation and a reward."""
        env = OpenEnv()
        try:
            env.reset()
            state = env.state()
            assert isinstance(state, EnvState)
            assert isinstance(state.observation, Observation)
            assert isinstance(state.reward, Reward)
        finally:
            env.close()

class TestRewardFunction:
    """Reward values returned by ``step`` for a correct and a penalised action."""

    def test_correct_action_reward(self):
        """The action matching the email's flags earns a reward of 1.0."""
        env = OpenEnv(config=EnvConfig(num_emails=1))
        # try/finally guarantees close() runs even when an assertion fails.
        try:
            obs, _ = env.reset()
            email = obs.current_email

            # Determine ground truth locally to force the correct action.
            # NOTE(review): this mapping presumably mirrors the environment's
            # own labelling rules -- keep the two in sync.
            if email.is_spam:
                act = 4
            elif email.is_urgent:
                act = 2 if "forward" in email.body.lower() else 1
            elif email.is_internal:
                act = 1 if "?" in email.body else 3
            else:
                act = 3

            _, reward, _, _, _ = env.step(Action(action_type=act))
            assert reward == 1.0
        finally:
            env.close()

    def test_critical_failure_penalty(self):
        """Taking action 0 on an urgent email yields a reward of -5.0."""
        # urgent_ratio=1.0 / spam_ratio=0.0 is used so the first email is
        # urgent; the assertion below verifies that precondition.
        env = OpenEnv(config=EnvConfig(num_emails=10, urgent_ratio=1.0, spam_ratio=0.0))
        try:
            obs, _ = env.reset()
            assert obs.current_email.is_urgent

            # Action 0 is presumably "ignore/delete" (per the original comment).
            _, reward, _, _, _ = env.step(Action(action_type=0))
            assert reward == -5.0
        finally:
            env.close()

class TestTerminationConditions:
    """Episode termination behaviour of ``OpenEnv.step``."""

    def test_episode_termination(self):
        """A 2-email episode terminates on the second step and stays terminated."""
        config = EnvConfig(num_emails=2, verbose=False)
        env = OpenEnv(config=config)
        # try/finally guarantees close() runs even when an assertion fails.
        try:
            env.reset()

            _, _, term1, _, _ = env.step(Action(action_type=0))
            assert not term1

            _, _, term2, _, _ = env.step(Action(action_type=0))
            assert term2

            # Stepping past the end must be harmless: still terminated,
            # zero reward, and no emails left.
            obs, rem_rew, term3, _, _ = env.step(Action(action_type=0))
            assert term3
            assert rem_rew == 0.0
            assert obs.emails_remaining == 0
        finally:
            env.close()

class TestConfiguration:
    """Persistence checks for ``EnvConfig``."""

    def test_config_save_load(self, tmp_path):
        """``num_emails`` and ``task_level`` survive a save/load round trip."""
        # tmp_path is pytest's per-test temporary directory fixture.
        target = str(tmp_path / "config.json")

        original = EnvConfig(num_emails=55, task_level="hard")
        original.save(target)

        restored = EnvConfig.load(target)
        assert restored.num_emails == 55
        assert restored.task_level == "hard"

if __name__ == "__main__":
    # pytest.main() returns the run's exit code. Raise it as SystemExit so
    # executing this file directly exits non-zero when tests fail, instead
    # of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))