# dataqa-env/tests/test_inference.py
"""Tests for the inference script's parsing, prompt building, and log format."""
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from inference import (
    parse_llm_response,
    parse_fix_response,
    build_user_prompt,
    log_start,
    log_step,
    log_end,
)


class TestParseLLMResponse:
    def test_standard_format(self):
        response = "row:1,col:name,issue:missing_value\nrow:2,col:salary,issue:wrong_type"
        issues = parse_llm_response(response)
        assert len(issues) == 2
        assert "row:1,col:name,issue:missing_value" in issues

    def test_numbered_list(self):
        response = "1. row:1,col:name,issue:missing_value\n2. row:2,col:salary,issue:wrong_type"
        issues = parse_llm_response(response)
        assert len(issues) == 2

    def test_bullet_list(self):
        response = "- row:1,col:name,issue:missing_value\n* row:2,col:salary,issue:wrong_type"
        issues = parse_llm_response(response)
        assert len(issues) == 2

    def test_equals_delimiter(self):
        response = "row=1,col=name,issue=missing_value"
        issues = parse_llm_response(response)
        assert len(issues) == 1
        assert issues[0] == "row:1,col:name,issue:missing_value"

    def test_mixed_case(self):
        response = "Row:1,Col:Name,Issue:Missing_Value"
        issues = parse_llm_response(response)
        assert len(issues) == 1
        assert issues[0] == "row:1,col:name,issue:missing_value"

    def test_empty_response(self):
        assert parse_llm_response("") == []
        assert parse_llm_response(" ") == []

    def test_garbage_lines_skipped(self):
        response = "Here are the issues:\nrow:1,col:name,issue:missing_value\nNo more issues."
        issues = parse_llm_response(response)
        assert len(issues) == 1

    def test_deduplication_not_applied(self):
        response = "row:1,col:name,issue:missing_value\nrow:1,col:name,issue:missing_value"
        issues = parse_llm_response(response)
        assert len(issues) == 2

    def test_with_column_variant(self):
        response = "row:1,column:name,issue:missing_value"
        issues = parse_llm_response(response)
        assert len(issues) == 1
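

# A minimal sketch (not the real implementation) of a parser that would satisfy
# TestParseLLMResponse above. parse_llm_response lives in inference.py; the
# helper name and regexes here are assumptions reconstructed from the
# assertions: list markers are stripped, keys are lowercased, "=" becomes ":",
# "column" is normalized to "col", prose lines are dropped, duplicates are kept.
def _sketch_parse_llm_response(response: str) -> list[str]:
    import re  # local import keeps this sketch self-contained

    issues = []
    for raw in response.splitlines():
        # Strip leading list markers such as "1. ", "- ", or "* ".
        line = re.sub(r"^\s*(?:\d+\.|[-*])\s*", "", raw.strip())
        # Normalize case, "=" delimiters, and the "column" key variant.
        line = line.lower().replace("=", ":").replace("column:", "col:")
        if re.match(r"^row:[^,]+,col:[^,]+,issue:.+$", line):
            issues.append(line)  # duplicates are intentionally preserved
    return issues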


class TestParseFixResponse:
    def test_standard_format(self):
        response = "row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000"
        fixes = parse_fix_response(response)
        assert len(fixes) == 2
        assert "row:4,col:name,fix:David Kim" in fixes

    def test_numbered_list(self):
        response = "1. row:4,col:name,fix:David Kim\n2. row:7,col:salary,fix:75000"
        fixes = parse_fix_response(response)
        assert len(fixes) == 2

    def test_with_special_chars(self):
        response = "row:1,col:email,fix:alice.chen@company.com"
        fixes = parse_fix_response(response)
        assert len(fixes) == 1
        assert "alice.chen@company.com" in fixes[0]

    def test_empty_response(self):
        assert parse_fix_response("") == []

    def test_date_fix(self):
        response = "row:12,col:order_date,fix:2024-01-26"
        fixes = parse_fix_response(response)
        assert len(fixes) == 1

    def test_ignores_issue_lines(self):
        response = "row:4,col:name,issue:missing_value\nrow:4,col:name,fix:David Kim"
        fixes = parse_fix_response(response)
        assert len(fixes) == 1  # only the fix line
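

# A matching sketch for parse_fix_response, again an assumption derived from
# the tests: fix lines are matched case-insensitively, but the fix value
# itself ("David Kim", email addresses, dates) must survive verbatim, so only
# the row/col keys are lowercased. Lines carrying "issue:" fail the match and
# are skipped, which is what test_ignores_issue_lines checks.
def _sketch_parse_fix_response(response: str) -> list[str]:
    import re  # local import keeps this sketch self-contained

    fixes = []
    for raw in response.splitlines():
        line = re.sub(r"^\s*(?:\d+\.|[-*])\s*", "", raw.strip())
        m = re.match(r"^row:([^,]+),col:([^,]+),fix:(.+)$", line, re.IGNORECASE)
        if m:
            fixes.append(
                f"row:{m.group(1).lower()},col:{m.group(2).lower()},fix:{m.group(3)}"
            )
    return fixes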


class TestBuildUserPrompt:
    def test_includes_all_fields(self):
        obs = {
            "task_description": "Find issues",
            "schema_description": "col: int",
            "validation_rules": "no nulls",
            "dataset_csv": "a,b\n1,2",
            "num_issues_hint": 3,
            "feedback": "",
        }
        prompt = build_user_prompt(obs)
        assert "Find issues" in prompt
        assert "col: int" in prompt
        assert "no nulls" in prompt
        assert "a,b" in prompt
        assert "3 issues" in prompt

    def test_includes_feedback_on_retry(self):
        obs = {
            "task_description": "Find issues",
            "schema_description": "",
            "validation_rules": "",
            "dataset_csv": "a\n1",
            "num_issues_hint": 0,
            "feedback": "Step 1/3: You missed 2 issues",
        }
        prompt = build_user_prompt(obs)
        assert "FEEDBACK" in prompt
        assert "missed 2" in prompt

    def test_excludes_reset_feedback(self):
        obs = {
            "task_description": "",
            "schema_description": "",
            "validation_rules": "",
            "dataset_csv": "",
            "num_issues_hint": 0,
            "feedback": "Environment reset. Start inspecting.",
        }
        prompt = build_user_prompt(obs)
        assert "FEEDBACK" not in prompt

    def test_include_fixes_flag(self):
        obs = {
            "task_description": "Find issues",
            "schema_description": "",
            "validation_rules": "",
            "dataset_csv": "a\n1",
            "num_issues_hint": 0,
            "feedback": "",
        }
        prompt = build_user_prompt(obs, include_fixes=True)
        assert "fix" in prompt.lower()
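

# A hedged sketch of a prompt builder consistent with TestBuildUserPrompt.
# The section labels and exact wording below are assumptions; only the
# behaviors the tests pin down are taken from the assertions: all observation
# fields appear, a nonzero hint yields "N issues", retry feedback surfaces
# under a FEEDBACK marker, the reset banner is suppressed, and include_fixes
# adds fix instructions.
def _sketch_build_user_prompt(obs: dict, include_fixes: bool = False) -> str:
    parts = [
        obs["task_description"],
        obs["schema_description"],
        obs["validation_rules"],
        obs["dataset_csv"],
    ]
    if obs["num_issues_hint"]:
        parts.append(f"Report exactly {obs['num_issues_hint']} issues.")
    feedback = obs["feedback"]
    # The reset banner is boilerplate, not actionable feedback, so skip it.
    if feedback and not feedback.startswith("Environment reset"):
        parts.append(f"FEEDBACK: {feedback}")
    if include_fixes:
        parts.append("Then propose fixes as row:<r>,col:<c>,fix:<value> lines.")
    return "\n".join(parts)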


class TestLogFormat:
    """Verify stdout log format matches hackathon evaluation requirements."""

    def test_log_start_format(self, capsys):
        log_start(task="easy", env="dataqa_env", model="test-model")
        out = capsys.readouterr().out.strip()
        assert out == "[START] task=easy env=dataqa_env model=test-model"

    def test_log_step_format(self, capsys):
        log_step(step=1, action="row:1,col:name,issue:missing_value", reward=0.50, done=False, error=None)
        out = capsys.readouterr().out.strip()
        assert out == "[STEP] step=1 action=row:1,col:name,issue:missing_value reward=0.50 done=false error=null"

    def test_log_step_with_error(self, capsys):
        log_step(step=2, action="none", reward=0.00, done=True, error="timeout")
        out = capsys.readouterr().out.strip()
        assert "error=timeout" in out
        assert "done=true" in out

    def test_log_end_format(self, capsys):
        log_end(success=True, steps=3, score=0.85, rewards=[0.25, 0.50, 0.85])
        out = capsys.readouterr().out.strip()
        assert out == "[END] success=true steps=3 score=0.850 rewards=0.25,0.50,0.85"

    def test_log_end_failure(self, capsys):
        log_end(success=False, steps=1, score=0.0, rewards=[0.0])
        out = capsys.readouterr().out.strip()
        assert "success=false" in out
        assert "score=0.000" in out

    def test_reward_format_2_decimal(self, capsys):
        log_step(step=1, action="test", reward=0.123456, done=False, error=None)
        out = capsys.readouterr().out.strip()
        assert "reward=0.12" in out

    def test_no_newlines_within_line(self, capsys):
        log_start(task="easy", env="dataqa_env", model="model")
        log_step(step=1, action="act", reward=0.0, done=False, error=None)
        log_end(success=False, steps=1, score=0.0, rewards=[0.0])
        out = capsys.readouterr().out
        lines = [line for line in out.split("\n") if line.strip()]
        assert len(lines) == 3
        assert lines[0].startswith("[START]")
        assert lines[1].startswith("[STEP]")
        assert lines[2].startswith("[END]")
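

# Sketches of the three loggers, assuming plain print-to-stdout (the class
# docstring above says stdout). The format strings themselves are fully pinned
# by the assertions: booleans are lowercased, None renders as "null", reward
# uses 2 decimals, score uses 3, and rewards are comma-joined with no spaces.
def _sketch_log_start(task: str, env: str, model: str) -> None:
    print(f"[START] task={task} env={env} model={model}")


def _sketch_log_step(step: int, action: str, reward: float, done: bool, error) -> None:
    err = "null" if error is None else error
    print(f"[STEP] step={step} action={action} reward={reward:.2f} "
          f"done={str(done).lower()} error={err}")


def _sketch_log_end(success: bool, steps: int, score: float, rewards: list) -> None:
    joined = ",".join(f"{r:.2f}" for r in rewards)
    print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={joined}")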