Spaces:
Sleeping
Sleeping
| """Tests for task definitions, data corruption, and issue planting.""" | |
| import pytest | |
| from dataqa_env.server.tasks import ( | |
| PlantedIssue, | |
| Task, | |
| create_task_easy, | |
| create_task_medium, | |
| create_task_hard, | |
| get_task, | |
| list_tasks, | |
| _csv_to_rows, | |
| _rows_to_csv, | |
| ) | |
class TestPlantedIssue:
    """Checks on ``PlantedIssue``: key rendering and the ``difficulty`` field."""

    def test_to_key(self):
        """``to_key`` should combine row, column and issue type into one string."""
        planted = PlantedIssue(
            row=3,
            col="salary",
            issue_type="missing_value",
            description="test",
        )
        expected_key = "row:3,col:salary,issue:missing_value"
        assert planted.to_key() == expected_key

    def test_difficulty_default(self):
        """When no difficulty is passed, the test expects it to be 1.0."""
        planted = PlantedIssue(
            row=1,
            col="name",
            issue_type="missing_value",
            description="test",
        )
        assert planted.difficulty == 1.0

    def test_difficulty_custom(self):
        """An explicitly passed difficulty should be stored unchanged."""
        planted = PlantedIssue(
            row=1,
            col="name",
            issue_type="missing_value",
            description="test",
            difficulty=3.0,
        )
        assert planted.difficulty == 3.0
class TestCSVHelpers:
    """Checks on the private CSV helpers ``_csv_to_rows`` and ``_rows_to_csv``."""

    def test_roundtrip(self):
        """Parsing then re-serialising keeps the data line ``1,2,3`` in the output."""
        source_text = "a,b,c\n1,2,3\n4,5,6"
        parsed = _csv_to_rows(source_text)
        # Three input lines in, three rows expected back.
        assert len(parsed) == 3
        serialised = _rows_to_csv(parsed)
        assert "1,2,3" in serialised

    def test_empty_csv(self):
        """A CSV containing just a header line should parse to a single row."""
        parsed = _csv_to_rows("a,b\n")
        assert len(parsed) == 1  # header only
class TestTaskEasy:
    """Checks on the task built by ``create_task_easy``."""

    @pytest.fixture
    def task(self):
        """Provide a freshly created easy task to each test that requests ``task``."""
        # Without the fixture decorator pytest cannot resolve the ``task``
        # argument of the tests below and errors with "fixture 'task' not found".
        return create_task_easy()

    def test_task_id(self, task):
        """The easy task identifies itself as ``"easy"``."""
        assert task.task_id == "easy"

    def test_has_6_issues(self, task):
        """Exactly six issues are expected to be planted."""
        assert len(task.planted_issues) == 6

    def test_issue_types(self, task):
        """Each of the expected issue categories appears at least once."""
        types = {i.issue_type for i in task.planted_issues}
        assert "missing_value" in types
        assert "wrong_type" in types
        assert "duplicate_row" in types
        assert "format_violation" in types
        assert "inconsistent_value" in types

    def test_corrupted_csv_differs_from_clean(self, task):
        """The corrupted CSV must not be identical to the clean CSV."""
        assert task.corrupted_csv != task.clean_csv

    def test_issue_keys_unique(self, task):
        """No two planted issues may share the same key."""
        keys = [i.to_key() for i in task.planted_issues]
        assert len(keys) == len(set(keys))

    def test_max_steps(self, task):
        """The easy task is expected to allow three steps."""
        assert task.max_steps == 3

    def test_corrupted_csv_has_more_rows(self, task):
        """The corrupted CSV parses to more rows than the clean one."""
        clean_rows = _csv_to_rows(task.clean_csv)
        corrupt_rows = _csv_to_rows(task.corrupted_csv)
        assert len(corrupt_rows) > len(clean_rows)  # duplicate row added

    def test_difficulty_weights(self, task):
        """Every planted issue carries a difficulty within [1.0, 3.0]."""
        for issue in task.planted_issues:
            assert 1.0 <= issue.difficulty <= 3.0
class TestTaskMedium:
    """Checks on the task built by ``create_task_medium``."""

    @pytest.fixture
    def task(self):
        """Provide a freshly created medium task to each test that requests ``task``."""
        # Without the fixture decorator pytest cannot resolve the ``task``
        # argument of the tests below and errors with "fixture 'task' not found".
        return create_task_medium()

    def test_task_id(self, task):
        """The medium task identifies itself as ``"medium"``."""
        assert task.task_id == "medium"

    def test_has_8_issues(self, task):
        """Exactly eight issues are expected to be planted."""
        assert len(task.planted_issues) == 8

    def test_issue_types(self, task):
        """Each of the expected issue categories appears at least once."""
        types = {i.issue_type for i in task.planted_issues}
        assert "inconsistent_value" in types
        assert "format_violation" in types
        assert "wrong_type" in types

    def test_issue_keys_unique(self, task):
        """No two planted issues may share the same key."""
        keys = [i.to_key() for i in task.planted_issues]
        assert len(keys) == len(set(keys))

    def test_difficulty_weights(self, task):
        """Every planted issue carries a difficulty within [1.0, 3.0]."""
        for issue in task.planted_issues:
            assert 1.0 <= issue.difficulty <= 3.0
class TestTaskHard:
    """Checks on the task built by ``create_task_hard``."""

    @pytest.fixture
    def task(self):
        """Provide a freshly created hard task to each test that requests ``task``."""
        # Without the fixture decorator pytest cannot resolve the ``task``
        # argument of the tests below and errors with "fixture 'task' not found".
        return create_task_hard()

    def test_task_id(self, task):
        """The hard task identifies itself as ``"hard"``."""
        assert task.task_id == "hard"

    def test_has_10_issues(self, task):
        """Exactly ten issues are expected to be planted."""
        assert len(task.planted_issues) == 10

    def test_issue_types(self, task):
        """Each of the expected issue categories appears at least once."""
        types = {i.issue_type for i in task.planted_issues}
        assert "inconsistent_value" in types
        assert "format_violation" in types
        assert "statistical_outlier" in types
        assert "out_of_range" in types

    def test_has_high_difficulty_issues(self, task):
        """At least two planted issues have a difficulty of 2.5 or more."""
        hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
        assert len(hard_issues) >= 2  # data leakage, GPU outlier, whitespace

    def test_issue_keys_unique(self, task):
        """No two planted issues may share the same key."""
        keys = [i.to_key() for i in task.planted_issues]
        assert len(keys) == len(set(keys))
class TestTaskAlignment:
    """Checks on the ``"alignment"`` task obtained through ``get_task``."""

    @pytest.fixture
    def task(self):
        """Provide the alignment task to each test that requests ``task``."""
        # ``get_task`` is already imported at module level, so no per-test
        # re-import is needed.
        return get_task("alignment")

    def test_alignment_task(self, task):
        """The task reports its id and the expected number of planted issues."""
        assert task.task_id == "alignment"
        assert len(task.planted_issues) == 12

    def test_alignment_issue_types(self, task):
        """Each of the expected issue categories appears at least once."""
        types = {i.issue_type for i in task.planted_issues}
        assert "inconsistent_value" in types  # factual errors, mismatches, hallucinations
        assert "missing_value" in types  # truncated, whitespace-only
        assert "duplicate_row" in types  # duplicate instruction

    def test_alignment_has_high_difficulty(self, task):
        """At least three planted issues have a difficulty of 2.5 or more."""
        hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
        assert len(hard_issues) >= 3  # hallucinated citation, harmful advice, factual error

    def test_alignment_issue_keys_unique(self, task):
        """No two planted issues may share the same key."""
        keys = [i.to_key() for i in task.planted_issues]
        assert len(keys) == len(set(keys))

    def test_alignment_corrupted_differs(self, task):
        """The corrupted CSV must not be identical to the clean CSV."""
        assert task.corrupted_csv != task.clean_csv

    def test_alignment_in_env(self):
        """Submitting every planted issue key through the environment scores >= 0.99."""
        # Imported lazily so the other tests in this class do not depend on
        # the environment and model modules being importable.
        from dataqa_env.server.environment import DataQAEnvironment
        from dataqa_env.models import DataQAAction

        env = DataQAEnvironment()
        obs = env.reset(task_id="alignment")
        assert obs.num_issues_hint == 12

        # Perfect submission: fetch the reference task after the reset and
        # submit the key of every planted issue.
        reference = get_task("alignment")
        action = DataQAAction(
            issues=[i.to_key() for i in reference.planted_issues],
            task_id="alignment",
        )
        obs = env.step(action)
        assert obs.reward >= 0.99
class TestTaskModeration:
    """Checks on the ``"moderation"`` task, alone and inside the environment."""

    def test_moderation_task(self):
        """The task reports its id and the expected number of planted issues."""
        from dataqa_env.server.tasks import get_task

        moderation = get_task("moderation")
        assert moderation.task_id == "moderation"
        assert len(moderation.planted_issues) == 10

    def test_moderation_issue_types(self):
        """Each of the expected issue categories appears at least once."""
        from dataqa_env.server.tasks import get_task

        moderation = get_task("moderation")
        seen_types = {issue.issue_type for issue in moderation.planted_issues}
        assert "inconsistent_value" in seen_types
        assert "out_of_range" in seen_types
        assert "missing_value" in seen_types
        assert "duplicate_row" in seen_types

    def test_moderation_in_env(self):
        """Submitting every planted issue key through the environment scores >= 0.99."""
        from dataqa_env.server.environment import DataQAEnvironment
        from dataqa_env.models import DataQAAction
        from dataqa_env.server.tasks import get_task

        environment = DataQAEnvironment()
        observation = environment.reset(task_id="moderation")
        assert observation.num_issues_hint == 10

        moderation = get_task("moderation")
        all_keys = [issue.to_key() for issue in moderation.planted_issues]
        submission = DataQAAction(issues=all_keys, task_id="moderation")
        observation = environment.step(submission)
        assert observation.reward >= 0.99

    def test_moderation_deterministic(self):
        """The same action after the same seeded reset yields the same reward."""
        from dataqa_env.server.environment import DataQAEnvironment
        from dataqa_env.models import DataQAAction

        environment = DataQAEnvironment()
        environment.reset(task_id="moderation", seed=42)
        submission = DataQAAction(
            issues=["row:16,col:hate,issue:inconsistent_value"],
            task_id="moderation",
        )
        first_reward = environment.step(submission).reward

        # Reset with the identical seed and replay the very same action object.
        environment.reset(task_id="moderation", seed=42)
        second_reward = environment.step(submission).reward
        assert first_reward == second_reward
class TestTaskRegistry:
    """Checks on the ``list_tasks`` / ``get_task`` registry functions."""

    def test_list_tasks(self):
        """``list_tasks`` yields exactly the seven expected task ids."""
        expected_ids = {"easy", "medium", "hard", "alignment", "coding", "toolcalling", "moderation"}
        assert set(list_tasks()) == expected_ids

    def test_get_task_easy(self):
        """Looking up ``"easy"`` returns a task with that id."""
        assert get_task("easy").task_id == "easy"

    def test_get_task_medium(self):
        """Looking up ``"medium"`` returns a task with that id."""
        assert get_task("medium").task_id == "medium"

    def test_get_task_hard(self):
        """Looking up ``"hard"`` returns a task with that id."""
        assert get_task("hard").task_id == "hard"

    def test_get_task_unknown_raises(self):
        """An unregistered id raises ``ValueError`` whose message matches "Unknown task"."""
        with pytest.raises(ValueError, match="Unknown task"):
            get_task("nonexistent")

    def test_seed_determinism(self):
        """Two lookups with the same seed give the same CSV and the same issue keys."""
        first = get_task("easy", seed=42)
        second = get_task("easy", seed=42)
        assert first.corrupted_csv == second.corrupted_csv

        first_keys = [issue.to_key() for issue in first.planted_issues]
        second_keys = [issue.to_key() for issue in second.planted_issues]
        assert first_keys == second_keys