File size: 1,990 Bytes
7b49766
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from fastapi.testclient import TestClient

from models import DataCleanAction
from server.app import app
from server.data_clean_env_environment import DataCleanEnvironment


def test_easy_clean_solution_scores_expected_value() -> None:
    env = DataCleanEnvironment()
    env.reset(task="easy_clean")
    env.step(DataCleanAction(action_type="fill_na", column_name="age", value="0"))

    result = env.step(DataCleanAction(action_type="submit"))

    assert result.done is True
    assert result.reward == 0.99


def test_medium_clean_wrong_solution_is_not_near_perfect() -> None:
    env = DataCleanEnvironment()
    env.reset(task="medium_clean")
    env.step(DataCleanAction(action_type="fill_na", column_name="age", value="0"))
    env.step(DataCleanAction(action_type="drop_na", column_name="name"))
    env.step(DataCleanAction(action_type="drop_column", column_name="ignore_me"))

    result = env.step(DataCleanAction(action_type="submit"))

    assert result.reward < 0.99


def test_hard_clean_wrong_join_date_is_not_near_perfect() -> None:
    env = DataCleanEnvironment()
    env.reset(task="hard_clean")
    env.step(DataCleanAction(action_type="rename_column", column_name="EmployeeID", value="emp_id"))
    env.step(DataCleanAction(action_type="drop_column", column_name="Dept"))
    env.step(DataCleanAction(action_type="fill_na", column_name="Salary", value="0"))
    env.step(DataCleanAction(action_type="change_type", column_name="Salary", value="float"))
    env.step(DataCleanAction(action_type="fill_na", column_name="JoinDate", value="wrong-date"))

    result = env.step(DataCleanAction(action_type="submit"))

    assert result.reward < 0.99


def test_state_endpoint_keeps_core_state_fields() -> None:
    client = TestClient(app)
    client.post("/reset", json={"task": "easy_clean"})

    response = client.get("/state")

    assert response.status_code == 200
    payload = response.json()
    assert "episode_id" in payload
    assert "step_count" in payload