Spaces:
Sleeping
Sleeping
| """Tests for the inference script's parsing, prompt building, and log format.""" | |
| import pytest | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from inference import parse_llm_response, parse_fix_response, build_user_prompt, log_start, log_step, log_end | |
class TestParseLLMResponse:
    """Check ``parse_llm_response`` against the issue-line shapes exercised below."""

    def test_standard_format(self):
        """Two plain ``row/col/issue`` lines yield two parsed issues."""
        parsed = parse_llm_response(
            "row:1,col:name,issue:missing_value\nrow:2,col:salary,issue:wrong_type"
        )
        assert len(parsed) == 2
        assert "row:1,col:name,issue:missing_value" in parsed

    def test_numbered_list(self):
        """Lines prefixed with ``1.`` / ``2.`` are still counted as issues."""
        parsed = parse_llm_response(
            "1. row:1,col:name,issue:missing_value\n2. row:2,col:salary,issue:wrong_type"
        )
        assert len(parsed) == 2

    def test_bullet_list(self):
        """Lines prefixed with ``-`` or ``*`` are still counted as issues."""
        parsed = parse_llm_response(
            "- row:1,col:name,issue:missing_value\n* row:2,col:salary,issue:wrong_type"
        )
        assert len(parsed) == 2

    def test_equals_delimiter(self):
        """``key=value`` input comes back in the ``key:value`` form."""
        parsed = parse_llm_response("row=1,col=name,issue=missing_value")
        assert len(parsed) == 1
        assert parsed[0] == "row:1,col:name,issue:missing_value"

    def test_mixed_case(self):
        """Mixed-case input comes back fully lower-cased."""
        parsed = parse_llm_response("Row:1,Col:Name,Issue:Missing_Value")
        assert len(parsed) == 1
        assert parsed[0] == "row:1,col:name,issue:missing_value"

    def test_empty_response(self):
        """Empty and whitespace-only responses both produce an empty list."""
        for blank in ("", " "):
            assert parse_llm_response(blank) == []

    def test_garbage_lines_skipped(self):
        """Surrounding prose lines are not reported as issues."""
        parsed = parse_llm_response(
            "Here are the issues:\nrow:1,col:name,issue:missing_value\nNo more issues."
        )
        assert len(parsed) == 1

    def test_deduplication_not_applied(self):
        """An identical line repeated twice is returned twice."""
        parsed = parse_llm_response(
            "row:1,col:name,issue:missing_value\nrow:1,col:name,issue:missing_value"
        )
        assert len(parsed) == 2

    def test_with_column_variant(self):
        """A line using ``column:`` instead of ``col:`` is still accepted."""
        parsed = parse_llm_response("row:1,column:name,issue:missing_value")
        assert len(parsed) == 1
class TestParseFixResponse:
    """Check ``parse_fix_response`` against the fix-line shapes exercised below."""

    def test_standard_format(self):
        """Two plain ``row/col/fix`` lines yield two parsed fixes."""
        parsed = parse_fix_response(
            "row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000"
        )
        assert len(parsed) == 2
        assert "row:4,col:name,fix:David Kim" in parsed

    def test_numbered_list(self):
        """Lines prefixed with ``1.`` / ``2.`` are still counted as fixes."""
        parsed = parse_fix_response(
            "1. row:4,col:name,fix:David Kim\n2. row:7,col:salary,fix:75000"
        )
        assert len(parsed) == 2

    def test_with_special_chars(self):
        """An e-mail address in the fix value appears intact in the parsed entry."""
        parsed = parse_fix_response("row:1,col:email,fix:alice.chen@company.com")
        assert len(parsed) == 1
        assert "alice.chen@company.com" in parsed[0]

    def test_empty_response(self):
        """An empty response produces an empty list."""
        assert parse_fix_response("") == []

    def test_date_fix(self):
        """A hyphenated date as the fix value is accepted."""
        parsed = parse_fix_response("row:12,col:order_date,fix:2024-01-26")
        assert len(parsed) == 1

    def test_ignores_issue_lines(self):
        """An ``issue:`` line mixed in with a ``fix:`` line is not returned."""
        parsed = parse_fix_response(
            "row:4,col:name,issue:missing_value\nrow:4,col:name,fix:David Kim"
        )
        assert len(parsed) == 1  # only the fix line
class TestBuildUserPrompt:
    """Check what ``build_user_prompt`` includes for a given observation dict."""

    @staticmethod
    def _observation(**overrides):
        """Return an observation dict of empty defaults updated with *overrides*.

        The key order matches the literals the tests originally spelled out;
        ``dict.update`` on existing keys keeps that order.
        """
        observation = {
            "task_description": "",
            "schema_description": "",
            "validation_rules": "",
            "dataset_csv": "",
            "num_issues_hint": 0,
            "feedback": "",
        }
        observation.update(overrides)
        return observation

    def test_includes_all_fields(self):
        """Task, schema, rules, dataset and the issue-count hint all reach the prompt."""
        prompt = build_user_prompt(
            self._observation(
                task_description="Find issues",
                schema_description="col: int",
                validation_rules="no nulls",
                dataset_csv="a,b\n1,2",
                num_issues_hint=3,
            )
        )
        for expected in ("Find issues", "col: int", "no nulls", "a,b", "3 issues"):
            assert expected in prompt

    def test_includes_feedback_on_retry(self):
        """Step feedback is surfaced under a ``FEEDBACK`` marker."""
        prompt = build_user_prompt(
            self._observation(
                task_description="Find issues",
                dataset_csv="a\n1",
                feedback="Step 1/3: You missed 2 issues",
            )
        )
        assert "FEEDBACK" in prompt
        assert "missed 2" in prompt

    def test_excludes_reset_feedback(self):
        """The environment-reset message does not produce a ``FEEDBACK`` section."""
        prompt = build_user_prompt(
            self._observation(feedback="Environment reset. Start inspecting.")
        )
        assert "FEEDBACK" not in prompt

    def test_include_fixes_flag(self):
        """With ``include_fixes=True`` the prompt mentions fixes (case-insensitive)."""
        prompt = build_user_prompt(
            self._observation(task_description="Find issues", dataset_csv="a\n1"),
            include_fixes=True,
        )
        assert "fix" in prompt.lower()
class TestLogFormat:
    """Verify stdout log format matches hackathon evaluation requirements.

    Every test captures stdout through pytest's ``capsys`` fixture and checks
    the ``[START]`` / ``[STEP]`` / ``[END]`` lines printed by the log helpers.
    Partial checks compare whole whitespace-delimited ``key=value`` tokens
    rather than substrings, so a longer value with the expected prefix does
    not pass by accident.
    """

    def test_log_start_format(self, capsys):
        """``log_start`` prints exactly one ``[START]`` line in the expected layout."""
        log_start(task="easy", env="dataqa_env", model="test-model")
        out = capsys.readouterr().out.strip()
        assert out == "[START] task=easy env=dataqa_env model=test-model"

    def test_log_step_format(self, capsys):
        """``log_step`` prints a ``[STEP]`` line with lower-case bool and ``null`` error."""
        log_step(step=1, action="row:1,col:name,issue:missing_value", reward=0.50, done=False, error=None)
        out = capsys.readouterr().out.strip()
        assert out == "[STEP] step=1 action=row:1,col:name,issue:missing_value reward=0.50 done=false error=null"

    def test_log_step_with_error(self, capsys):
        """An error string and ``done=True`` are rendered as their own tokens."""
        log_step(step=2, action="none", reward=0.00, done=True, error="timeout")
        tokens = capsys.readouterr().out.split()
        assert "error=timeout" in tokens
        assert "done=true" in tokens

    def test_log_end_format(self, capsys):
        """``log_end`` prints a 3-decimal score and comma-joined 2-decimal rewards."""
        log_end(success=True, steps=3, score=0.85, rewards=[0.25, 0.50, 0.85])
        out = capsys.readouterr().out.strip()
        assert out == "[END] success=true steps=3 score=0.850 rewards=0.25,0.50,0.85"

    def test_log_end_failure(self, capsys):
        """A failed run reports ``success=false`` and a zero score with 3 decimals."""
        log_end(success=False, steps=1, score=0.0, rewards=[0.0])
        tokens = capsys.readouterr().out.split()
        assert "success=false" in tokens
        # Whole-token match: a substring check would also accept "score=0.0000".
        assert "score=0.000" in tokens

    def test_reward_format_2_decimal(self, capsys):
        """The step reward is rounded to exactly two decimal places."""
        log_step(step=1, action="test", reward=0.123456, done=False, error=None)
        tokens = capsys.readouterr().out.split()
        # Whole-token match: the previous substring check ("reward=0.12" in out)
        # also passed for an unrounded "reward=0.123456", so it never verified
        # the two-decimal requirement.
        assert "reward=0.12" in tokens

    def test_no_newlines_within_line(self, capsys):
        """Each helper writes exactly one line, in START / STEP / END order."""
        log_start(task="easy", env="dataqa_env", model="model")
        log_step(step=1, action="act", reward=0.0, done=False, error=None)
        log_end(success=False, steps=1, score=0.0, rewards=[0.0])
        out = capsys.readouterr().out
        lines = [line for line in out.split("\n") if line.strip()]
        assert len(lines) == 3
        assert lines[0].startswith("[START]")
        assert lines[1].startswith("[STEP]")
        assert lines[2].startswith("[END]")