File size: 3,640 Bytes
0b81240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""

tests/test_env.py — local smoke tests for DataSelectEnv



Run with:

    python tests/test_env.py

"""

import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
from env import DataSelectEnv
from models import Action

# Shared configuration handed to every DataSelectEnv built in this file.
# The tests below read "budget" and "max_steps" back out of it directly.
BASE_CFG = {
    "data": {
        "n_samples": 1500,
        "n_features": 20,
        "n_informative": 5,
        "n_redundant": 5,
        "flip_y": 0.1,
    },
    "budget": 300,
    "max_steps": 15,
    "alpha": 0.2,
}


def test_reset_reproducible():
    """Resetting one seeded env twice must produce identical observations.

    Both resets are performed first; only then are the two observations
    serialised with ``model_dump()`` and compared for equality.
    """
    env = DataSelectEnv(BASE_CFG, seed=42)
    first_obs, second_obs = env.reset(), env.reset()
    first_dump = first_obs.model_dump()
    second_dump = second_obs.model_dump()
    assert first_dump == second_dump, "reset() not reproducible!"
    print("PASS: reset() is reproducible")


def test_step_runs():
    """A full episode must complete without errors.

    Steps a fresh ``DataSelectEnv`` with a fixed ``select_batch`` action until
    the environment reports ``done``. Every step must return a ``float``
    reward and an ``info`` mapping that contains ``"noise_ratio"``.

    The loop is capped so that an environment which never terminates makes
    this test fail with a clear message instead of hanging the whole run.
    """
    env = DataSelectEnv(BASE_CFG, seed=42)
    obs = env.reset()
    action = Action(
        action_type="select_batch",
        batch_size=10,
        strategy_weights={"uncertainty": 0.4, "diversity": 0.4, "random": 0.2},
    )
    # Deliberately generous bound: the episode is expected to end long before
    # this; the cap only converts an infinite loop into a test failure.
    step_cap = BASE_CFG["max_steps"] + BASE_CFG["budget"]
    done = False
    steps = 0
    while not done:
        assert steps < step_cap, f"episode did not terminate within {step_cap} steps"
        obs, reward, done, info = env.step(action)
        steps += 1
        assert isinstance(reward, float)
        assert "noise_ratio" in info
    print(f"PASS: episode completed in {steps} steps, final_perf={obs.current_performance:.4f}")


def test_get_state():
    """Right after reset(), get_state() must describe a fresh, unfinished episode.

    Expected: zero steps taken, the full configured budget remaining, and
    ``done`` falsy.
    """
    env = DataSelectEnv(BASE_CFG, seed=42)
    env.reset()
    state = env.get_state()
    full_budget = BASE_CFG["budget"]

    assert state.step_count == 0
    assert state.remaining_budget == full_budget
    assert not state.done
    print("PASS: get_state() after reset is correct")


def test_noise_mask_sync():
    """noise_mask must stay in sync with X_pool throughout the episode.

    Checks ``len(noise_mask) == len(X_pool)`` on the env's private
    ``_episode_state``:
    - right after ``reset()``;
    - after each of up to five ``select_batch`` steps, stopping early if the
      episode ends.
    """
    env = DataSelectEnv(BASE_CFG, seed=42)
    env.reset()

    def assert_in_sync(when):
        # Reads a private attribute on purpose: the invariant is internal.
        s = env._episode_state
        assert len(s.noise_mask) == len(s.X_pool), f"noise_mask out of sync {when}!"

    # The original loop only checked post-step states; the initial state
    # produced by reset() is part of the episode too.
    assert_in_sync("after reset")

    action = Action(
        action_type="select_batch",
        batch_size=10,
        strategy_weights={"uncertainty": 0.4, "diversity": 0.4, "random": 0.2},
    )
    for step in range(1, 6):
        _, _, done, _ = env.step(action)
        assert_in_sync(f"after step {step}")
        if done:
            break
    print("PASS: noise_mask stays in sync with X_pool")


def test_strategies():
    """Run 3 strategies and verify balanced beats uncertainty-only.

    Each strategy gets a fresh ``DataSelectEnv`` with the same seed. It is
    stepped with a fixed ``select_batch`` action until the episode ends, and
    its final ``current_performance`` is recorded and printed. Only the
    balanced-vs-uncertainty comparison is asserted; the random run is
    reported for reference.

    Each episode loop is capped so a non-terminating environment fails the
    test instead of hanging it.
    """
    strategies = {
        "balanced":    {"uncertainty": 0.4, "diversity": 0.4, "random": 0.2},
        "uncertainty": {"uncertainty": 0.95, "diversity": 0.03, "random": 0.02},
        "random":      {"uncertainty": 0.0,  "diversity": 0.0,  "random": 1.0},
    }
    # Deliberately generous bound; it only exists to turn a hang into a failure.
    step_cap = BASE_CFG["max_steps"] + BASE_CFG["budget"]
    results = {}
    for name, weights in strategies.items():
        env = DataSelectEnv(BASE_CFG, seed=42)
        obs = env.reset()
        action = Action(action_type="select_batch", batch_size=10, strategy_weights=weights)
        done = False
        steps = 0
        while not done:
            assert steps < step_cap, f"{name}: episode did not terminate within {step_cap} steps"
            obs, _, done, _ = env.step(action)
            steps += 1
        results[name] = obs.current_performance
        print(f"  {name:15s} final_perf={obs.current_performance:.4f}")

    assert results["balanced"] >= results["uncertainty"], (
        "Balanced should outperform uncertainty-only! "
        f"(balanced={results['balanced']:.4f}, uncertainty={results['uncertainty']:.4f})"
    )
    print("PASS: balanced strategy outperforms uncertainty-only")


if __name__ == "__main__":
    # Plain-script entry point: run every smoke test in order and stop at the
    # first failing assertion.
    print("Running DataSelectEnv smoke tests...\n")
    for run_smoke_test in (
        test_reset_reproducible,
        test_step_runs,
        test_get_state,
        test_noise_mask_sync,
        test_strategies,
    ):
        run_smoke_test()
    print("\nAll tests passed.")