# dataqa-env / tests/test_extensibility.py
# Uploaded by varb15 ("Upload folder using huggingface_hub"), commit f5583f9 (verified).
"""Tests for the extensibility API — custom tasks and contamination rules."""
import pytest
from dataqa_env.server.tasks import (
PlantedIssue,
create_task_from_config,
register_task,
register_contamination_rule,
CONTAMINATION_RULES,
get_task,
list_tasks,
)
from dataqa_env.server.environment import DataQAEnvironment, compute_weighted_reward
from dataqa_env.models import DataQAAction
# Clean baseline CSV shared by every test: a header row plus four data rows,
# joined with "\n" and with no trailing newline.
SIMPLE_CSV = "\n".join(
    [
        "id,name,score",
        "1,Alice,95",
        "2,Bob,87",
        "3,Carol,92",
        "4,Dave,78",
    ]
)
class TestCreateTaskFromConfig:
    """Exercise ``create_task_from_config`` with built-in and callable rules."""

    @staticmethod
    def _build(task_id, name, contaminations, *, description="",
               schema_description="", validation_rules=""):
        """Build a task over ``SIMPLE_CSV``; omitted text fields default to ``""``."""
        return create_task_from_config(
            task_id=task_id,
            name=name,
            description=description,
            schema_description=schema_description,
            validation_rules=validation_rules,
            clean_csv=SIMPLE_CSV,
            contaminations=contaminations,
        )

    def test_basic_creation(self):
        """A single missing_value rule on column 1 plants one issue on ``name``."""
        task = self._build(
            "test_custom",
            "Test Task",
            [{"rule": "missing_value", "row": 0, "col": 1}],
            description="Test",
            schema_description="id: int, name: str, score: int",
            validation_rules="No missing values",
        )
        assert task.task_id == "test_custom"
        assert len(task.planted_issues) == 1
        issue = task.planted_issues[0]
        assert issue.issue_type == "missing_value"
        assert issue.col == "name"

    def test_multiple_contaminations(self):
        """Two contamination entries produce two planted issues."""
        specs = [{"rule": "missing_value", "row": r, "col": 1} for r in (0, 2)]
        task = self._build("multi", "Multi", specs, description="Test")
        assert len(task.planted_issues) == 2

    def test_custom_difficulty_override(self):
        """A ``difficulty`` key in the spec is carried onto the planted issue."""
        override = {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 2.5}
        task = self._build(
            "custom_diff", "Custom Difficulty", [override], description="Test"
        )
        assert task.planted_issues[0].difficulty == 2.5

    def test_callable_rule(self):
        """A callable may be passed directly as the ``rule`` value."""

        def custom_rule(rows, header, col_idx, row_idx, rng):
            planted = PlantedIssue(
                row=row_idx + 1,
                col=header[col_idx],
                issue_type="wrong_type",
                description="Custom corruption",
                difficulty=1.5,
            )
            return "CORRUPTED", planted

        task = self._build(
            "callable",
            "Callable Rule",
            [{"rule": custom_rule, "row": 1, "col": 2}],
            description="Test",
        )
        assert task.planted_issues[0].issue_type == "wrong_type"
        assert "CORRUPTED" in task.corrupted_csv

    def test_unknown_rule_raises(self):
        """An unregistered rule name is rejected with ValueError."""
        bogus = [{"rule": "nonexistent_rule", "row": 0, "col": 0}]
        with pytest.raises(ValueError, match="Unknown contamination rule"):
            self._build("bad", "Bad", bogus)
class TestRegisterContaminationRule:
    """``register_contamination_rule`` makes a named rule usable in task configs."""

    def test_register_and_use(self):
        """A registered rule can be referenced by name in a contamination spec."""

        def reverse_value(rows, header, col_idx, row_idx, rng):
            val = rows[row_idx][col_idx]
            return val[::-1], PlantedIssue(
                row=row_idx + 1, col=header[col_idx], issue_type="format_violation",
                description="Reversed value", difficulty=1.5,
            )

        try:
            register_contamination_rule("reverse", reverse_value)
            assert "reverse" in CONTAMINATION_RULES
            task = create_task_from_config(
                task_id="rev_test",
                name="Reverse Test",
                description="",
                schema_description="",
                validation_rules="",
                clean_csv=SIMPLE_CSV,
                contaminations=[{"rule": "reverse", "row": 0, "col": 1}],
            )
            assert task.planted_issues[0].issue_type == "format_violation"
            # "Alice" reversed is "ecilA"
            assert "ecilA" in task.corrupted_csv
        finally:
            # Unregister even when an assertion above fails, so the shared
            # module-level registry does not leak "reverse" into other tests.
            # pop(..., None) also tolerates registration itself having failed.
            CONTAMINATION_RULES.pop("reverse", None)
class TestRegisterTask:
    """``register_task`` exposes a task factory through ``list_tasks``/``get_task``."""

    def test_register_and_get(self):
        """A registered factory is listed and returns the configured task."""
        from dataqa_env.server.tasks import TASK_REGISTRY

        task = create_task_from_config(
            task_id="registered",
            name="Registered Task",
            description="Test registered task",
            schema_description="id: int, name: str",
            validation_rules="No missing values",
            clean_csv=SIMPLE_CSV,
            contaminations=[{"rule": "missing_value", "row": 1, "col": 1}],
        )
        try:
            register_task("registered", lambda seed: task)
            assert "registered" in list_tasks()
            fetched = get_task("registered")
            assert fetched.task_id == "registered"
            assert len(fetched.planted_issues) == 1
        finally:
            # Always remove the entry so a failing assertion cannot leave
            # "registered" in the shared TASK_REGISTRY for later tests.
            TASK_REGISTRY.pop("registered", None)
class TestCustomTaskInEnvironment:
    """Custom tasks run end-to-end through ``DataQAEnvironment`` reset/step."""

    def test_full_lifecycle_identify_only(self):
        """Custom task works end-to-end with identify-only."""
        from dataqa_env.server.tasks import TASK_REGISTRY

        task = create_task_from_config(
            task_id="e2e_custom",
            name="E2E Custom",
            description="End-to-end test",
            schema_description="id: int, name: str, score: int",
            validation_rules="No missing values",
            clean_csv=SIMPLE_CSV,
            contaminations=[
                {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
                {"rule": "whitespace_value", "row": 2, "col": 1, "difficulty": 2.5},
            ],
        )
        try:
            register_task("e2e_custom", lambda seed: task)
            env = DataQAEnvironment()
            obs = env.reset(task_id="e2e_custom")
            assert obs.num_issues_hint == 2
            # Submitting exactly the planted issue keys should earn ~full reward.
            action = DataQAAction(
                issues=[i.to_key() for i in task.planted_issues],
                task_id="e2e_custom",
            )
            obs = env.step(action)
            assert obs.done is True
            assert obs.reward >= 0.999
        finally:
            # Unregister even on failure so the shared registry stays clean.
            TASK_REGISTRY.pop("e2e_custom", None)

    def test_full_lifecycle_identify_and_fix(self):
        """Custom task works end-to-end with both identify and fix."""
        from dataqa_env.server.tasks import TASK_REGISTRY

        task = create_task_from_config(
            task_id="e2e_fix",
            name="E2E Fix",
            description="End-to-end test with fixes",
            schema_description="id: int, name: str, score: int",
            validation_rules="No missing values",
            clean_csv=SIMPLE_CSV,
            contaminations=[
                {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
            ],
        )
        try:
            register_task("e2e_fix", lambda seed: task)
            env = DataQAEnvironment()
            env.reset(task_id="e2e_fix")
            # Submit issues + fix
            action = DataQAAction(
                issues=[task.planted_issues[0].to_key()],
                fixes=["row:1,col:name,fix:Alice"],  # clean value is "Alice"
                task_id="e2e_fix",
            )
            obs = env.step(action)
            assert obs.done is True
            assert obs.metadata["fix_score"] > 0.0
            assert obs.metadata["combined_reward"] > 0.0
        finally:
            # Unregister even on failure so the shared registry stays clean.
            TASK_REGISTRY.pop("e2e_fix", None)