# sql-query-reviewer / tests/test_reward.py
# Last commit by hellinferno (b83c8ad):
#   improve: 20 tasks, richer keywords, enhanced reward/grader, bigram matching, compelling README
from __future__ import annotations
import pytest
from sql_query_reviewer.models import GroundTruthIssue, SQLReviewAction
from server.reward import compute_reward
def _action(action_type: str, confidence: float = 0.5) -> SQLReviewAction:
    """Build a minimal ``SQLReviewAction`` of ``action_type`` for reward tests.

    ``identify_issue`` actions also carry a fixed ``syntax`` category and the
    description ``"some issue"``; ``suggest_fix`` actions carry the fix
    ``"SELECT 1;"``. Any other action type gets only its type and confidence.
    """
    extra_fields: dict[str, str] = {}
    if action_type == "identify_issue":
        extra_fields = {
            "issue_category": "syntax",
            "issue_description": "some issue",
        }
    elif action_type == "suggest_fix":
        extra_fields = {"suggested_fix": "SELECT 1;"}
    return SQLReviewAction(
        action_type=action_type,
        confidence=confidence,
        **extra_fields,
    )
def _issue(severity: float = 0.35) -> GroundTruthIssue:
    """Build a single ``syntax`` ground-truth issue with the given ``severity``.

    Every other field is fixed: id ``test_issue_001``, fix ``"SELECT 1;"`` and
    the one keyword ``"test"``.
    """
    issue_fields = dict(
        id="test_issue_001",
        category="syntax",
        description="A test issue.",
        severity=severity,
        fix="SELECT 1;",
        keywords=["test"],
    )
    return GroundTruthIssue(**issue_fields)
# ── identify_issue ────────────────────────────────────────────────────────────
def test_identify_issue_duplicate_returns_small_penalty() -> None:
    """An ``identify_issue`` scored with ``duplicate_issue=True`` yields -0.02."""
    reward = compute_reward(
        _action("identify_issue"),
        _issue(),
        duplicate_issue=True,
    )
    assert reward == pytest.approx(-0.02)
def test_identify_issue_no_match_returns_penalty() -> None:
    """An ``identify_issue`` with no matched ground-truth issue yields -0.1."""
    action = _action("identify_issue")
    matched_issue = None
    reward = compute_reward(action, matched_issue)
    assert reward == pytest.approx(-0.1)
def test_identify_issue_match_no_fix_zero_confidence() -> None:
    """A matched issue with no fix and zero confidence scores 0.39."""
    # base_reward = min(0.35, 0.35) = 0.35; fix_bonus = 0; confidence_bonus = 0
    # order_bonus = 0.04 * (1/(0+1)) = 0.04 → total = 0.39
    assert compute_reward(_action("identify_issue", confidence=0.0), _issue(0.35)) == pytest.approx(0.39)
def test_identify_issue_match_no_fix_full_confidence() -> None:
    """A matched issue with no fix and confidence 1.0 scores 0.418."""
    # base=0.35 + confidence_bonus=min(0.05, 1.0*0.35*0.08)=0.028 + order_bonus=0.04 → 0.418
    assert compute_reward(_action("identify_issue", confidence=1.0), _issue(0.35)) == pytest.approx(0.418)
def test_identify_issue_match_with_fix_zero_confidence() -> None:
    """A matched issue with a valid fix is clamped to the 0.45 ceiling."""
    # Uncapped: base 0.35 + fix_bonus 0.08 + order_bonus 0.04 = 0.47,
    # which the reward caps at 0.45.
    action = _action("identify_issue", confidence=0.0)
    reward = compute_reward(action, _issue(0.35), fix_valid=True)
    assert reward == pytest.approx(0.45)
def test_identify_issue_high_severity_capped_at_035_base() -> None:
    """Severity above 0.35 still contributes a base reward of only 0.35."""
    # base = min(0.9, 0.35) = 0.35; adding order_bonus 0.04 gives 0.39.
    action = _action("identify_issue", confidence=0.0)
    severe_issue = _issue(severity=0.9)
    reward = compute_reward(action, severe_issue)
    assert reward == pytest.approx(0.39)
# ── suggest_fix ───────────────────────────────────────────────────────────────
def test_suggest_fix_without_previous_issue_is_penalized() -> None:
    """``suggest_fix`` with ``has_previous_issue=False`` yields -0.05."""
    reward = compute_reward(
        _action("suggest_fix"),
        None,
        has_previous_issue=False,
    )
    assert reward == pytest.approx(-0.05)
def test_suggest_fix_with_previous_issue_invalid_fix() -> None:
    """An invalid fix after a previous issue earns nothing (0.0)."""
    reward = compute_reward(
        _action("suggest_fix"),
        _issue(),
        has_previous_issue=True,
        fix_valid=False,
    )
    assert reward == pytest.approx(0.0)
def test_suggest_fix_with_previous_issue_valid_fix() -> None:
    """A valid fix after a previous issue earns 0.1."""
    reward = compute_reward(
        _action("suggest_fix"),
        _issue(),
        has_previous_issue=True,
        fix_valid=True,
    )
    assert reward == pytest.approx(0.1)
# ── approve ───────────────────────────────────────────────────────────────────
def test_approve_all_issues_found_gives_positive_reward() -> None:
    """Approving with ``remaining_unfound=0`` earns 0.2."""
    reward = compute_reward(_action("approve"), None, remaining_unfound=0)
    assert reward == pytest.approx(0.2)
def test_approve_one_issue_missed_gives_penalty() -> None:
    """Approving with one issue still unfound costs -0.15."""
    reward = compute_reward(_action("approve"), None, remaining_unfound=1)
    assert reward == pytest.approx(-0.15)
def test_approve_many_issues_missed_floors_at_negative_one() -> None:
    """Approving with seven issues unfound hits the -1.0 floor."""
    # -0.15 * 7 = -1.05 → floored at -1.0
    assert compute_reward(_action("approve"), None, remaining_unfound=7) == pytest.approx(-1.0)
# ── request_more_context ──────────────────────────────────────────────────────
def test_request_more_context_returns_zero() -> None:
    """``request_more_context`` without ``schema_available`` scores 0.0."""
    # No schema_available → returns 0.0
    assert compute_reward(_action("request_more_context"), None) == pytest.approx(0.0)
def test_request_more_context_with_schema_returns_penalty() -> None:
    """``request_more_context`` with ``schema_available=True`` costs -0.03."""
    # schema_available=True → returns -0.03
    assert compute_reward(_action("request_more_context"), None, schema_available=True) == pytest.approx(-0.03)