Spaces:
Sleeping
Sleeping
File size: 3,640 Bytes
"""
tests/test_env.py — local smoke tests for DataSelectEnv.

Run with:
    python tests/test_env.py
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
from env import DataSelectEnv
from models import Action
# Shared configuration handed to every DataSelectEnv built by these smoke tests.
BASE_CFG = {
    # NOTE(review): these keys look like synthetic-dataset generator options
    # (the names match scikit-learn's make_classification) — confirm in env.py.
    "data": {
        "n_samples": 1500,
        "n_features": 20,
        "n_informative": 5,
        "n_redundant": 5,
        "flip_y": 0.1,
    },
    # test_get_state expects remaining_budget to equal this right after reset().
    "budget": 300,
    "max_steps": 15,
    # NOTE(review): meaning of alpha is defined by the environment — not visible here.
    "alpha": 0.2,
}
def test_reset_reproducible():
    """Check that consecutive reset() calls on one seeded env give equal observations.

    A single DataSelectEnv is built with seed=42 and reset twice; the two
    observations are compared through their model_dump() output.
    """
    env = DataSelectEnv(BASE_CFG, seed=42)
    # Perform both resets before dumping so the comparison sees the same
    # objects the environment handed back.
    snapshots = [env.reset() for _ in range(2)]
    assert snapshots[0].model_dump() == snapshots[1].model_dump(), "reset() not reproducible!"
    print("PASS: reset() is reproducible")
def test_step_runs():
    """Drive one full episode with a fixed action and sanity-check each step.

    Every step must yield a float reward and an info object that contains
    "noise_ratio". The loop ends when the environment reports done.
    """
    env = DataSelectEnv(BASE_CFG, seed=42)
    obs = env.reset()

    mix = {"uncertainty": 0.4, "diversity": 0.4, "random": 0.2}
    action = Action(action_type="select_batch", batch_size=10, strategy_weights=mix)

    steps = 0
    while True:
        obs, reward, done, info = env.step(action)
        steps += 1
        assert isinstance(reward, float)
        assert "noise_ratio" in info
        if done:
            break

    print(f"PASS: episode completed in {steps} steps, final_perf={obs.current_performance:.4f}")
def test_get_state():
    """Verify what get_state() reports immediately after reset().

    Expected: zero steps taken, the full configured budget remaining, and the
    episode not marked done.
    """
    env = DataSelectEnv(BASE_CFG, seed=42)
    env.reset()

    state = env.get_state()
    expected_budget = BASE_CFG["budget"]

    assert state.step_count == 0
    assert state.remaining_budget == expected_budget
    assert not state.done
    print("PASS: get_state() after reset is correct")
def test_noise_mask_sync():
    """Check that noise_mask and X_pool keep the same length while stepping.

    Takes up to five steps (stopping early if the episode finishes) and
    compares the two lengths on the env's internal episode state after each.
    """
    env = DataSelectEnv(BASE_CFG, seed=42)
    env.reset()

    mix = {"uncertainty": 0.4, "diversity": 0.4, "random": 0.2}
    action = Action(action_type="select_batch", batch_size=10, strategy_weights=mix)

    for _step in range(5):
        # Only the done flag is needed from the step result.
        _obs, _reward, done, _info = env.step(action)
        # Reads the private episode state directly; there is no public accessor
        # used for these fields in this file.
        episode = env._episode_state
        assert len(episode.noise_mask) == len(episode.X_pool), "noise_mask out of sync!"
        if done:
            break

    print("PASS: noise_mask stays in sync with X_pool")
def test_strategies():
    """Run three weightings for a full episode each and compare final performance.

    Each strategy gets a fresh env seeded with 42. The test passes when the
    balanced mix finishes at least as high as the uncertainty-heavy mix.
    """

    def run_episode(weights):
        # Play one episode with a constant action; return the last observation.
        env = DataSelectEnv(BASE_CFG, seed=42)
        obs = env.reset()
        action = Action(action_type="select_batch", batch_size=10, strategy_weights=weights)
        while True:
            obs, _, done, _ = env.step(action)
            if done:
                return obs

    strategies = {
        "balanced": {"uncertainty": 0.4, "diversity": 0.4, "random": 0.2},
        "uncertainty": {"uncertainty": 0.95, "diversity": 0.03, "random": 0.02},
        "random": {"uncertainty": 0.0, "diversity": 0.0, "random": 1.0},
    }

    results = {}
    for name, weights in strategies.items():
        obs = run_episode(weights)
        results[name] = obs.current_performance
        print(f" {name:15s} final_perf={obs.current_performance:.4f}")

    assert results["balanced"] >= results["uncertainty"], (
        "Balanced should outperform uncertainty-only!"
    )
    print("PASS: balanced strategy outperforms uncertainty-only")
if __name__ == "__main__":
    # Script entry point: run each smoke test in order. A failing assertion
    # propagates and aborts the run, so the final message is printed only
    # when every test passed.
    print("Running DataSelectEnv smoke tests...\n")
    for smoke_test in (
        test_reset_reproducible,
        test_step_runs,
        test_get_state,
        test_noise_mask_sync,
        test_strategies,
    ):
        smoke_test()
    print("\nAll tests passed.")