# dataqa-env/tests/test_environment.py
# Uploaded by varb15 via huggingface_hub (commit c5b540e, verified)
"""Tests for the DataQA environment (reset, step, scoring, two-phase identify+fix)."""
import pytest
from dataqa_env.server.environment import (
DataQAEnvironment,
parse_issue_key,
parse_fix,
compute_f1,
compute_weighted_reward,
grade_fixes,
IDENTIFY_WEIGHT,
FIX_WEIGHT,
)
from dataqa_env.models import DataQAAction
from dataqa_env.server.tasks import PlantedIssue, create_task_easy, create_task_medium
# ──────────────────────────────────────────────────────
# Issue parsing
# ──────────────────────────────────────────────────────
class TestParseIssueKey:
    """parse_issue_key: normalize well-formed issue keys, return None otherwise."""

    def test_standard_format(self):
        # A canonically-formatted key round-trips unchanged.
        canonical = "row:3,col:salary,issue:missing_value"
        assert parse_issue_key(canonical) == canonical

    def test_with_equals(self):
        # '=' separators are accepted and normalized to ':'.
        normalized = parse_issue_key("row=3,col=salary,issue=missing_value")
        assert normalized == "row:3,col:salary,issue:missing_value"

    def test_case_insensitive(self):
        # Mixed case is lowered in the normalized key.
        normalized = parse_issue_key("Row:3,Col:Salary,Issue:Missing_Value")
        assert normalized == "row:3,col:salary,issue:missing_value"

    def test_with_spaces(self):
        # Whitespace around separators is stripped.
        normalized = parse_issue_key("row: 3, col: salary, issue: missing_value")
        assert normalized == "row:3,col:salary,issue:missing_value"

    def test_unparseable(self):
        # Free text that is not a key yields None.
        assert parse_issue_key("this is garbage") is None

    def test_partial_match(self):
        # A key missing the issue component is rejected.
        assert parse_issue_key("row:3,col:salary") is None

    def test_empty_string(self):
        assert parse_issue_key("") is None

    def test_semicolon_separator(self):
        # Semicolons also work as field separators.
        normalized = parse_issue_key("row:3;col:salary;issue:missing_value")
        assert normalized == "row:3,col:salary,issue:missing_value"
# ──────────────────────────────────────────────────────
# Fix parsing
# ──────────────────────────────────────────────────────
class TestParseFix:
    """parse_fix: extract a (row, column, replacement-value) triple from text."""

    def test_standard_format(self):
        parsed = parse_fix("row:4,col:name,fix:Alice Chen")
        assert parsed == (4, "name", "Alice Chen")

    def test_with_equals(self):
        # '=' separators are accepted in place of ':'.
        parsed = parse_fix("row=4,col=name,fix=Alice Chen")
        assert parsed == (4, "name", "Alice Chen")

    def test_numeric_fix(self):
        # Numeric replacement values stay strings.
        parsed = parse_fix("row:7,col:salary,fix:75000")
        assert parsed == (7, "salary", "75000")

    def test_date_fix(self):
        # Dashes inside the value must survive parsing.
        parsed = parse_fix("row:12,col:order_date,fix:2024-01-26")
        assert parsed == (12, "order_date", "2024-01-26")

    def test_case_insensitive(self):
        # Field labels and column names are lowercased; the value keeps its case.
        parsed = parse_fix("Row:4,Col:Name,Fix:Alice Chen")
        assert parsed == (4, "name", "Alice Chen")

    def test_unparseable(self):
        # Garbage and incomplete specs both yield None.
        assert parse_fix("garbage") is None
        assert parse_fix("row:4,col:name") is None

    def test_fix_with_special_chars(self):
        # Dots and '@' inside the value must survive parsing.
        parsed = parse_fix("row:1,col:email,fix:alice.chen@company.com")
        assert parsed == (1, "email", "alice.chen@company.com")
# ──────────────────────────────────────────────────────
# F1 scoring
# ──────────────────────────────────────────────────────
class TestComputeF1:
    """compute_f1 over sets of reported vs. planted issue keys."""

    def test_perfect_match(self):
        # Reporting exactly the planted set gives a perfect F1.
        identical = {"row:1,col:a,issue:missing_value"}
        assert compute_f1(identical, identical)["f1"] == 1.0

    def test_no_reported_no_planted(self):
        # Vacuous case: nothing planted and nothing reported counts as perfect.
        assert compute_f1(set(), set())["f1"] == 1.0

    def test_no_reported_some_planted(self):
        metrics = compute_f1(set(), {"row:1,col:a,issue:missing_value"})
        assert metrics["f1"] == 0.0
        assert metrics["fn"] == 1

    def test_all_false_positives(self):
        metrics = compute_f1(
            {"row:99,col:x,issue:wrong_type"},
            {"row:1,col:a,issue:missing_value"},
        )
        assert metrics["f1"] == 0.0

    def test_partial_match(self):
        # One hit, one spurious report, one miss.
        reported = {"row:1,col:a,issue:missing_value", "row:2,col:b,issue:wrong_type"}
        planted = {"row:1,col:a,issue:missing_value", "row:3,col:c,issue:duplicate_row"}
        metrics = compute_f1(reported, planted)
        assert metrics["tp"] == 1
        assert metrics["fp"] == 1
        assert metrics["fn"] == 1
        assert 0 < metrics["f1"] < 1

    def test_precision_recall_calculation(self):
        # 2 of 3 reported are correct; 2 of 3 planted are found.
        metrics = compute_f1({"a", "b", "c"}, {"a", "b", "d"})
        assert metrics["precision"] == pytest.approx(2 / 3)
        assert metrics["recall"] == pytest.approx(2 / 3)
# ──────────────────────────────────────────────────────
# Weighted reward
# ──────────────────────────────────────────────────────
class TestComputeWeightedReward:
    """compute_weighted_reward: difficulty-weighted scoring of reported issues."""

    @staticmethod
    def _issue(row, col, issue_type, difficulty):
        # Shorthand for a PlantedIssue with an empty description.
        return PlantedIssue(
            row=row, col=col, issue_type=issue_type, description="", difficulty=difficulty
        )

    def test_perfect_match(self):
        issues = [
            self._issue(1, "a", "missing_value", 1.0),
            self._issue(2, "b", "wrong_type", 3.0),
        ]
        every_key = {issue.to_key() for issue in issues}
        assert compute_weighted_reward(every_key, issues)["weighted_reward"] == 1.0

    def test_empty_both(self):
        # Nothing planted and nothing reported is a perfect outcome.
        assert compute_weighted_reward(set(), [])["weighted_reward"] == 1.0

    def test_no_reported(self):
        issues = [self._issue(1, "a", "missing_value", 2.0)]
        assert compute_weighted_reward(set(), issues)["weighted_reward"] == 0.0

    def test_hard_issue_worth_more(self):
        easy = self._issue(1, "a", "missing_value", 1.0)
        hard = self._issue(2, "b", "statistical_outlier", 3.0)
        both = [easy, hard]
        # Finding only the hard issue should beat finding only the easy one.
        found_hard = compute_weighted_reward({hard.to_key()}, both)
        found_easy = compute_weighted_reward({easy.to_key()}, both)
        assert found_hard["weighted_reward"] > found_easy["weighted_reward"]

    def test_false_positives_reduce_reward(self):
        issues = [self._issue(1, "a", "missing_value", 1.0)]
        exact = {issues[0].to_key()}
        noisy = exact | {"row:99,col:x,issue:wrong_type"}
        # Adding a spurious report must cost something.
        assert (
            compute_weighted_reward(exact, issues)["weighted_reward"]
            > compute_weighted_reward(noisy, issues)["weighted_reward"]
        )
# ──────────────────────────────────────────────────────
# Fix grading
# ──────────────────────────────────────────────────────
class TestGradeFixes:
    """grade_fixes: scoring of (row, col, value) fix proposals against a task."""

    @pytest.fixture
    def easy_task(self):
        return create_task_easy()

    def test_no_fixes_no_issues(self):
        # A pristine dataset with no fixes submitted scores perfectly.
        from dataqa_env.server.tasks import Task

        pristine = Task(
            task_id="empty",
            name="",
            description="",
            schema_description="",
            validation_rules="",
            clean_csv="a\n1",
        )
        assert grade_fixes([], pristine)["fix_score"] == 1.0

    def test_no_fixes_submitted(self, easy_task):
        graded = grade_fixes([], easy_task)
        assert graded["fix_score"] == 0.0
        assert graded["fixes_attempted"] == 0

    def test_exact_fix_for_missing_name(self, easy_task):
        # Row 4 has an empty name; the clean value is "David Kim".
        graded = grade_fixes([(4, "name", "David Kim")], easy_task)
        assert graded["fix_score"] > 0.0
        assert graded["fixes_correct"] == 1

    def test_exact_fix_for_wrong_type_salary(self, easy_task):
        # Row 7 holds "seventy-five thousand"; the clean value is "75000".
        graded = grade_fixes([(7, "salary", "75000")], easy_task)
        assert graded["fixes_correct"] == 1

    def test_misspelling_fix(self, easy_task):
        # Row 11 department is misspelled "Engneering"; correct to "Engineering".
        graded = grade_fixes([(11, "department", "Engineering")], easy_task)
        assert graded["fixes_correct"] == 1

    def test_wrong_value_for_issue_cell(self, easy_task):
        # Right cell (row 4 name is empty) but wrong replacement value.
        graded = grade_fixes([(4, "name", "Wrong Person")], easy_task)
        assert graded["fixes_partial"] == 1  # correct cell, wrong value
        assert graded["fix_score"] > 0.0  # gets partial credit

    def test_fix_for_non_issue_cell(self, easy_task):
        # Row 1 name is clean; "fixing" it counts as wrong.
        graded = grade_fixes([(1, "name", "Some Name")], easy_task)
        assert graded["fixes_wrong"] == 1
        assert graded["fix_score"] == 0.0

    def test_multiple_fixes_best_wins(self, easy_task):
        # Two proposals for the same cell: the better one should count.
        proposals = [
            (4, "name", "Wrong Person"),  # partial credit
            (4, "name", "David Kim"),  # exact match
        ]
        assert grade_fixes(proposals, easy_task)["fixes_correct"] >= 1

    def test_all_fixes_correct(self, easy_task):
        # Exact values for the deterministic planted issues.
        proposals = [
            (4, "name", "David Kim"),  # inferred from email
            (7, "salary", "75000"),  # type conversion
            (11, "department", "Engineering"),  # spelling fix
            (15, "email", "oscar.rivera@company.com"),  # pattern match
            (12, "start_date", "2022-11-03"),  # date format fix
        ]
        assert grade_fixes(proposals, easy_task)["fix_score"] > 0.7

    def test_fix_score_bounded(self, easy_task):
        # Even with a nonsense fix mixed in, the score stays in [0, 1].
        graded = grade_fixes([(4, "name", "David Kim"), (99, "x", "bad")], easy_task)
        assert 0.0 <= graded["fix_score"] <= 1.0
# ──────────────────────────────────────────────────────
# Full environment lifecycle
# ──────────────────────────────────────────────────────
class TestDataQAEnvironment:
    """Lifecycle tests: reset, step, rewards, state tracking, and metadata."""

    @pytest.fixture
    def env(self):
        # Fresh environment instance per test.
        return DataQAEnvironment()

    def test_reset_returns_observation(self, env):
        first = env.reset(task_id="easy")
        # Every core observation field must be populated.
        assert first.dataset_csv
        assert first.schema_description
        assert first.validation_rules
        assert first.task_description
        assert first.num_issues_hint == 6
        assert first.max_steps == 3
        assert first.done is False
        assert first.reward < 0.01  # clamped to 0.001, not exactly 0.0
        assert "fix" in first.feedback.lower()  # mentions fix phase

    def test_reset_medium(self, env):
        assert env.reset(task_id="medium").num_issues_hint == 8

    def test_reset_hard(self, env):
        assert env.reset(task_id="hard").num_issues_hint == 10

    def test_step_identify_only(self, env):
        """Backward compatible: only issues, no fixes."""
        from dataqa_env.server.tasks import get_task

        env.reset(task_id="easy")
        # Report every planted issue for the easy task.
        every_key = [planted.to_key() for planted in get_task("easy").planted_issues]
        outcome = env.step(DataQAAction(issues=every_key, task_id="easy"))
        assert outcome.done is True
        assert outcome.reward >= 0.999

    def test_step_with_fixes_increases_reward(self, env):
        """Submitting correct fixes should produce high combined reward."""
        from dataqa_env.server.tasks import get_task

        env.reset(task_id="easy")
        # NOTE(review): sibling tests fix the department cell at row 11, not
        # row 9 — confirm which row the easy task actually plants.
        submission = DataQAAction(
            issues=[planted.to_key() for planted in get_task("easy").planted_issues],
            fixes=[
                "row:4,col:name,fix:David Kim",
                "row:7,col:salary,fix:75000",
                "row:9,col:department,fix:Engineering",
            ],
            task_id="easy",
        )
        assert env.step(submission).metadata["combined_reward"] > 0.7

    def test_step_with_partial_issues(self, env):
        env.reset(task_id="easy")
        outcome = env.step(
            DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
        )
        # One correct report earns partial credit without ending the episode.
        assert 0 < outcome.reward < 1.0
        assert outcome.done is False

    def test_step_with_no_issues(self, env):
        env.reset(task_id="easy")
        outcome = env.step(DataQAAction(issues=[], task_id="easy"))
        assert outcome.reward < 0.01  # clamped to 0.001, not exactly 0.0

    def test_step_exhausts_max_steps(self, env):
        env.reset(task_id="easy")
        # Three wrong submissions should hit max_steps and end the episode.
        for _ in range(3):
            outcome = env.step(
                DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
            )
        assert outcome.done is True

    def test_auto_reset_on_step(self, env):
        # Stepping without an explicit reset should lazily start the task.
        outcome = env.step(
            DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
        )
        assert outcome.task_id == "easy"

    def test_state_tracking(self, env):
        env.reset(task_id="easy")
        assert env.state.task_id == "easy"
        assert env.state.current_step == 0
        assert env.state.best_score == 0.0
        env.step(
            DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
        )
        assert env.state.current_step == 1
        assert env.state.best_score > 0.0

    def test_best_score_monotonic(self, env):
        env.reset(task_id="easy")
        decent = DataQAAction(
            issues=["row:4,col:name,issue:missing_value", "row:7,col:salary,issue:wrong_type"],
            task_id="easy",
        )
        env.step(decent)
        checkpoint = env.state.best_score
        # A worse follow-up submission must not lower the best score.
        env.step(DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy"))
        assert env.state.best_score >= checkpoint

    def test_metadata_includes_both_phases(self, env):
        env.reset(task_id="easy")
        outcome = env.step(
            DataQAAction(
                issues=["row:4,col:name,issue:missing_value"],
                fixes=["row:4,col:name,fix:David Kim"],
                task_id="easy",
            )
        )
        # Metadata must cover both the identify and the fix phase.
        for key in (
            "identify_f1",
            "identify_score",
            "fix_score",
            "combined_reward",
            "tp",
            "fixes_correct",
            "fixes_attempted",
        ):
            assert key in outcome.metadata

    def test_parse_error_in_feedback(self, env):
        env.reset(task_id="easy")
        outcome = env.step(DataQAAction(issues=["garbage input"], task_id="easy"))
        assert "Parse error" in outcome.feedback

    def test_concurrent_sessions_flag(self):
        assert DataQAEnvironment.SUPPORTS_CONCURRENT_SESSIONS is True

    def test_reward_between_0_and_1(self, env):
        """Hackathon requirement: scores must be 0.0-1.0."""
        env.reset(task_id="hard")
        for _ in range(3):
            outcome = env.step(
                DataQAAction(
                    issues=["row:1,col:x,issue:wrong_type", "row:99,col:y,issue:missing_value"],
                    fixes=["row:1,col:x,fix:wrong"],
                    task_id="hard",
                )
            )
            assert 0.0 <= outcome.reward <= 1.0

    def test_combined_reward_weights(self, env):
        """Verify combined = IDENTIFY_WEIGHT * identify + FIX_WEIGHT * fix."""
        env.reset(task_id="easy")
        outcome = env.step(
            DataQAAction(
                issues=["row:4,col:name,issue:missing_value"],
                fixes=["row:4,col:name,fix:David Kim"],
                task_id="easy",
            )
        )
        meta = outcome.metadata
        predicted = IDENTIFY_WEIGHT * meta["identify_score"] + FIX_WEIGHT * meta["fix_score"]
        assert abs(meta["combined_reward"] - predicted) < 0.01

    def test_fix_feedback_shown_when_fixes_submitted(self, env):
        env.reset(task_id="easy")
        outcome = env.step(
            DataQAAction(
                issues=["row:4,col:name,issue:missing_value"],
                fixes=["row:4,col:name,fix:David Kim"],
                task_id="easy",
            )
        )
        # Feedback must surface both fix grading and the combined reward.
        assert "Fix Proposals" in outcome.feedback
        assert "Combined Reward" in outcome.feedback

    def test_no_fix_penalty_when_no_fixes_submitted(self, env):
        """If agent submits no fixes, reward = identify_score (no penalty)."""
        from dataqa_env.server.tasks import get_task

        env.reset(task_id="easy")
        outcome = env.step(
            DataQAAction(
                issues=[planted.to_key() for planted in get_task("easy").planted_issues],
                task_id="easy",
            )
        )
        assert outcome.reward >= 0.99
        assert outcome.metadata["combined_reward"] == outcome.metadata["identify_score"]