# File: dataqa_env/tests/test_tasks.py
# Uploader: varb15
# Commit message: Upload folder using huggingface_hub
# Commit: 17adce2 (verified)
"""Tests for task definitions, data corruption, and issue planting."""
import pytest
from dataqa_env.server.tasks import (
PlantedIssue,
Task,
create_task_easy,
create_task_medium,
create_task_hard,
get_task,
list_tasks,
_csv_to_rows,
_rows_to_csv,
)
class TestPlantedIssue:
    """Checks on ``PlantedIssue``: its key string and its difficulty weight."""

    def test_to_key(self):
        """``to_key`` combines row, column and issue type into one string."""
        planted = PlantedIssue(
            row=3,
            col="salary",
            issue_type="missing_value",
            description="test",
        )
        expected_key = "row:3,col:salary,issue:missing_value"
        assert planted.to_key() == expected_key

    def test_difficulty_default(self):
        """An issue built without ``difficulty`` reports a weight of 1.0."""
        planted = PlantedIssue(
            row=1,
            col="name",
            issue_type="missing_value",
            description="test",
        )
        assert planted.difficulty == 1.0

    def test_difficulty_custom(self):
        """An explicitly supplied ``difficulty`` is kept as given."""
        fields = {
            "row": 1,
            "col": "name",
            "issue_type": "missing_value",
            "description": "test",
        }
        planted = PlantedIssue(**fields, difficulty=3.0)
        assert planted.difficulty == 3.0
class TestCSVHelpers:
    """Checks on the private CSV helpers ``_csv_to_rows`` and ``_rows_to_csv``."""

    def test_roundtrip(self):
        """Parsing three lines yields three rows; re-serialising keeps a data line."""
        parsed = _csv_to_rows("a,b,c\n1,2,3\n4,5,6")
        # Header plus two data lines.
        assert len(parsed) == 3

        serialised = _rows_to_csv(parsed)
        assert "1,2,3" in serialised

    def test_empty_csv(self):
        """A CSV holding only a header line parses to a single row."""
        parsed = _csv_to_rows("a,b\n")
        row_count = len(parsed)
        assert row_count == 1  # header only
class TestTaskEasy:
    """Checks on the task produced by ``create_task_easy``."""

    @pytest.fixture
    def task(self):
        """Build the easy task for each test that requests it."""
        return create_task_easy()

    def test_task_id(self, task):
        """The task identifies itself as ``"easy"``."""
        assert task.task_id == "easy"

    def test_has_6_issues(self, task):
        """Exactly six issues are planted."""
        planted_count = len(task.planted_issues)
        assert planted_count == 6

    def test_issue_types(self, task):
        """Each of the five expected issue types appears at least once."""
        types = {planted.issue_type for planted in task.planted_issues}
        for expected in (
            "missing_value",
            "wrong_type",
            "duplicate_row",
            "format_violation",
            "inconsistent_value",
        ):
            assert expected in types

    def test_corrupted_csv_differs_from_clean(self, task):
        """The corrupted CSV text is not equal to the clean CSV text."""
        assert task.corrupted_csv != task.clean_csv

    def test_issue_keys_unique(self, task):
        """No two planted issues share the same key string."""
        keys = [planted.to_key() for planted in task.planted_issues]
        distinct = set(keys)
        assert len(keys) == len(distinct)

    def test_max_steps(self, task):
        """The task reports a ``max_steps`` of 3."""
        assert task.max_steps == 3

    def test_corrupted_csv_has_more_rows(self, task):
        """The corrupted CSV parses to more rows than the clean one."""
        clean_count = len(_csv_to_rows(task.clean_csv))
        corrupt_count = len(_csv_to_rows(task.corrupted_csv))
        assert corrupt_count > clean_count  # duplicate row added

    def test_difficulty_weights(self, task):
        """Every planted issue has a difficulty within [1.0, 3.0]."""
        assert all(
            1.0 <= planted.difficulty <= 3.0 for planted in task.planted_issues
        )
class TestTaskMedium:
    """Checks on the task produced by ``create_task_medium``."""

    @pytest.fixture
    def task(self):
        """Build the medium task for each test that requests it."""
        return create_task_medium()

    def test_task_id(self, task):
        """The task identifies itself as ``"medium"``."""
        assert task.task_id == "medium"

    def test_has_8_issues(self, task):
        """Exactly eight issues are planted."""
        planted_count = len(task.planted_issues)
        assert planted_count == 8

    def test_issue_types(self, task):
        """Each of the three expected issue types appears at least once."""
        types = {planted.issue_type for planted in task.planted_issues}
        for expected in ("inconsistent_value", "format_violation", "wrong_type"):
            assert expected in types

    def test_issue_keys_unique(self, task):
        """No two planted issues share the same key string."""
        keys = [planted.to_key() for planted in task.planted_issues]
        distinct = set(keys)
        assert len(keys) == len(distinct)

    def test_difficulty_weights(self, task):
        """Every planted issue has a difficulty within [1.0, 3.0]."""
        assert all(
            1.0 <= planted.difficulty <= 3.0 for planted in task.planted_issues
        )
class TestTaskHard:
    """Checks on the task produced by ``create_task_hard``."""

    @pytest.fixture
    def task(self):
        """Build the hard task for each test that requests it."""
        return create_task_hard()

    def test_task_id(self, task):
        """The task identifies itself as ``"hard"``."""
        assert task.task_id == "hard"

    def test_has_10_issues(self, task):
        """Exactly ten issues are planted."""
        planted_count = len(task.planted_issues)
        assert planted_count == 10

    def test_issue_types(self, task):
        """Each of the four expected issue types appears at least once."""
        types = {planted.issue_type for planted in task.planted_issues}
        for expected in (
            "inconsistent_value",
            "format_violation",
            "statistical_outlier",
            "out_of_range",
        ):
            assert expected in types

    def test_has_high_difficulty_issues(self, task):
        """At least two planted issues carry a difficulty of 2.5 or more."""
        high_count = sum(
            1 for planted in task.planted_issues if planted.difficulty >= 2.5
        )
        assert high_count >= 2  # data leakage, GPU outlier, whitespace

    def test_issue_keys_unique(self, task):
        """No two planted issues share the same key string."""
        keys = [planted.to_key() for planted in task.planted_issues]
        distinct = set(keys)
        assert len(keys) == len(distinct)
class TestTaskAlignment:
    """Checks on the ``"alignment"`` task obtained through ``get_task``.

    The ``task`` fixture returns the alignment task itself, and every test
    that inspects the task requests it. ``get_task`` comes from the
    module-level import, so no test re-imports it locally.
    """

    @pytest.fixture
    def task(self):
        """Fetch the alignment task for each test that requests it."""
        return get_task("alignment")

    def test_alignment_task(self, task):
        """The task identifies itself as ``"alignment"`` and plants 12 issues."""
        assert task.task_id == "alignment"
        assert len(task.planted_issues) == 12

    def test_alignment_issue_types(self, task):
        """Each of the three expected issue types appears at least once."""
        types = {i.issue_type for i in task.planted_issues}
        assert "inconsistent_value" in types  # factual errors, mismatches, hallucinations
        assert "missing_value" in types  # truncated, whitespace-only
        assert "duplicate_row" in types  # duplicate instruction

    def test_alignment_has_high_difficulty(self, task):
        """At least three planted issues carry a difficulty of 2.5 or more."""
        hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
        assert len(hard_issues) >= 3  # hallucinated citation, harmful advice, factual error

    def test_alignment_issue_keys_unique(self, task):
        """No two planted issues share the same key string."""
        keys = [i.to_key() for i in task.planted_issues]
        assert len(keys) == len(set(keys))

    def test_alignment_corrupted_differs(self, task):
        """The corrupted CSV text is not equal to the clean CSV text."""
        assert task.corrupted_csv != task.clean_csv

    def test_alignment_in_env(self, task):
        """Submitting every planted issue key to the environment earns >= 0.99."""
        # Environment and action model are not imported at module level,
        # so they are brought in here where they are needed.
        from dataqa_env.server.environment import DataQAEnvironment
        from dataqa_env.models import DataQAAction

        env = DataQAEnvironment()
        obs = env.reset(task_id="alignment")
        assert obs.num_issues_hint == 12

        # Perfect submission: report exactly the keys of the planted issues.
        action = DataQAAction(
            issues=[i.to_key() for i in task.planted_issues],
            task_id="alignment",
        )
        obs = env.step(action)
        assert obs.reward >= 0.99
class TestTaskModeration:
    """Checks on the ``"moderation"`` task obtained through ``get_task``.

    A ``task`` fixture supplies the task to the tests that inspect it.
    ``get_task`` comes from the module-level import, so no test re-imports
    it locally.
    """

    @pytest.fixture
    def task(self):
        """Fetch the moderation task for each test that requests it."""
        return get_task("moderation")

    def test_moderation_task(self, task):
        """The task identifies itself as ``"moderation"`` and plants 10 issues."""
        assert task.task_id == "moderation"
        assert len(task.planted_issues) == 10

    def test_moderation_issue_types(self, task):
        """Each of the four expected issue types appears at least once."""
        types = {i.issue_type for i in task.planted_issues}
        assert "inconsistent_value" in types
        assert "out_of_range" in types
        assert "missing_value" in types
        assert "duplicate_row" in types

    def test_moderation_in_env(self, task):
        """Submitting every planted issue key to the environment earns >= 0.99."""
        # Environment and action model are not imported at module level,
        # so they are brought in here where they are needed.
        from dataqa_env.server.environment import DataQAEnvironment
        from dataqa_env.models import DataQAAction

        env = DataQAEnvironment()
        obs = env.reset(task_id="moderation")
        assert obs.num_issues_hint == 10

        # Perfect submission: report exactly the keys of the planted issues.
        action = DataQAAction(
            issues=[i.to_key() for i in task.planted_issues],
            task_id="moderation",
        )
        obs = env.step(action)
        assert obs.reward >= 0.99

    def test_moderation_deterministic(self):
        """The same action after two resets with seed 42 yields equal rewards."""
        from dataqa_env.server.environment import DataQAEnvironment
        from dataqa_env.models import DataQAAction

        env = DataQAEnvironment()
        env.reset(task_id="moderation", seed=42)
        a = DataQAAction(
            issues=["row:16,col:hate,issue:inconsistent_value"],
            task_id="moderation",
        )
        r1 = env.step(a).reward

        # Reset the same environment with the same seed and replay the action.
        env.reset(task_id="moderation", seed=42)
        r2 = env.step(a).reward
        assert r1 == r2
class TestTaskRegistry:
    """Checks on the registry functions ``list_tasks`` and ``get_task``."""

    def test_list_tasks(self):
        """``list_tasks`` returns exactly the seven expected task ids."""
        expected_ids = {
            "easy",
            "medium",
            "hard",
            "alignment",
            "coding",
            "toolcalling",
            "moderation",
        }
        assert set(list_tasks()) == expected_ids

    def test_get_task_easy(self):
        """Looking up ``"easy"`` returns a task with that id."""
        assert get_task("easy").task_id == "easy"

    def test_get_task_medium(self):
        """Looking up ``"medium"`` returns a task with that id."""
        assert get_task("medium").task_id == "medium"

    def test_get_task_hard(self):
        """Looking up ``"hard"`` returns a task with that id."""
        assert get_task("hard").task_id == "hard"

    def test_get_task_unknown_raises(self):
        """An unregistered id raises ``ValueError`` matching "Unknown task"."""
        with pytest.raises(ValueError, match="Unknown task"):
            get_task("nonexistent")

    def test_seed_determinism(self):
        """Two easy tasks built with seed 42 have the same CSV and issue keys."""
        first = get_task("easy", seed=42)
        second = get_task("easy", seed=42)

        assert first.corrupted_csv == second.corrupted_csv

        first_keys = [planted.to_key() for planted in first.planted_issues]
        second_keys = [planted.to_key() for planted in second.planted_issues]
        assert first_keys == second_keys