# Scraped page header (not part of the module): Spaces: Sleeping; File size: 1,990 Bytes; 7b49766
from fastapi.testclient import TestClient
from models import DataCleanAction
from server.app import app
from server.data_clean_env_environment import DataCleanEnvironment
def test_easy_clean_solution_scores_expected_value() -> None:
    """Check the reward for one fill-then-submit run on the ``easy_clean`` task.

    The test resets a fresh environment to ``easy_clean``, fills missing
    values in the ``age`` column with ``"0"``, submits, and expects the
    episode to be finished with a reward of 0.99.
    """
    # Function-local import keeps this block self-contained without touching
    # the module's import section.
    import math

    env = DataCleanEnvironment()
    env.reset(task="easy_clean")
    env.step(DataCleanAction(action_type="fill_na", column_name="age", value="0"))
    result = env.step(DataCleanAction(action_type="submit"))

    assert result.done is True
    # The reward is compared against a float literal; use a tight absolute
    # tolerance instead of ``==`` so representation noise in the score
    # computation cannot cause a spurious failure.
    assert math.isclose(result.reward, 0.99, rel_tol=0.0, abs_tol=1e-9)
def test_medium_clean_wrong_solution_is_not_near_perfect() -> None:
    """Check that the action plan below scores under 0.99 on ``medium_clean``.

    The test name labels this plan a wrong solution; the test only asserts
    that the reward returned by the final ``submit`` step is below 0.99.
    """
    cleaner = DataCleanEnvironment()
    cleaner.reset(task="medium_clean")

    # Each entry holds the keyword arguments for one DataCleanAction; actions
    # are built and applied one at a time, in this order.
    plan = (
        {"action_type": "fill_na", "column_name": "age", "value": "0"},
        {"action_type": "drop_na", "column_name": "name"},
        {"action_type": "drop_column", "column_name": "ignore_me"},
    )
    for action_kwargs in plan:
        cleaner.step(DataCleanAction(**action_kwargs))

    outcome = cleaner.step(DataCleanAction(action_type="submit"))
    assert outcome.reward < 0.99
def test_hard_clean_wrong_join_date_is_not_near_perfect() -> None:
    """Check that filling ``JoinDate`` with a bogus value scores under 0.99.

    On the ``hard_clean`` task the test renames ``EmployeeID`` to ``emp_id``,
    drops ``Dept``, fills and retypes ``Salary``, then fills ``JoinDate`` with
    the string ``"wrong-date"`` before submitting. The only assertion is that
    the reward from the ``submit`` step is below 0.99.
    """
    cleaner = DataCleanEnvironment()
    cleaner.reset(task="hard_clean")

    # Keyword arguments for each DataCleanAction, applied strictly in order.
    plan = (
        {"action_type": "rename_column", "column_name": "EmployeeID", "value": "emp_id"},
        {"action_type": "drop_column", "column_name": "Dept"},
        {"action_type": "fill_na", "column_name": "Salary", "value": "0"},
        {"action_type": "change_type", "column_name": "Salary", "value": "float"},
        {"action_type": "fill_na", "column_name": "JoinDate", "value": "wrong-date"},
    )
    for action_kwargs in plan:
        cleaner.step(DataCleanAction(**action_kwargs))

    outcome = cleaner.step(DataCleanAction(action_type="submit"))
    assert outcome.reward < 0.99
def test_state_endpoint_keeps_core_state_fields() -> None:
    """Check that ``GET /state`` returns 200 with the core state keys present.

    The test posts to ``/reset`` with the ``easy_clean`` task, then fetches
    ``/state`` and expects the JSON payload to contain ``episode_id`` and
    ``step_count``.
    """
    http = TestClient(app)
    http.post("/reset", json={"task": "easy_clean"})

    state_response = http.get("/state")
    assert state_response.status_code == 200

    body = state_response.json()
    for required_key in ("episode_id", "step_count"):
        assert required_key in body
|