# sql_env/tests/unit/test_rewards.py
# hjerpe's picture
# Upload folder using huggingface_hub
# 5dd1bb4 verified
"""Unit tests for training reward callables."""
from sql_env.training.rewards import (
reward_correctness,
reward_operational,
reward_progress,
)
def _completions(size: int) -> list[list[dict[str, str]]]:
return [[{"role": "assistant", "content": "QUERY: SELECT 1"}] for _ in range(size)]
def test_correctness_correct_answer() -> None:
    """A ``correct: True`` metadata entry is expected to score 1.0."""
    episode_metadata = [{"correct": True}]
    rewards = reward_correctness(_completions(1), metadata=episode_metadata)
    assert rewards == [1.0]
def test_correctness_wrong_answer() -> None:
    """A ``correct: False`` metadata entry is expected to score 0.0."""
    episode_metadata = [{"correct": False}]
    rewards = reward_correctness(_completions(1), metadata=episode_metadata)
    assert rewards == [0.0]
def test_correctness_no_answer() -> None:
    """An empty metadata entry (no correctness key) is expected to score 0.0."""
    episode_metadata = [{}]
    rewards = reward_correctness(_completions(1), metadata=episode_metadata)
    assert rewards == [0.0]
def test_correctness_batch() -> None:
    """Score a four-item batch that mixes ``answer_correct`` and ``correct`` keys."""
    episode_metadata = [
        {"answer_correct": True},
        {"answer_correct": False},
        {"correct": True},
        {"correct": False},
    ]
    expected = [1.0, 0.0, 1.0, 0.0]

    rewards = reward_correctness(_completions(4), metadata=episode_metadata)

    assert rewards == expected
def test_correctness_empty_batch() -> None:
    """With no completions and no metadata, the reward list is expected empty."""
    no_completions: list[list[dict[str, str]]] = []
    rewards = reward_correctness(no_completions)
    assert rewards == []
def test_correctness_trl_compatible() -> None:
    """Check the return shape: one ``float`` reward per completion.

    The test name ties this to TRL compatibility; what is verified here is
    that the result has one entry per completion and every entry is a float.
    """
    batch_size = 2
    result = reward_correctness(
        _completions(batch_size),
        metadata=[{"correct": True}, {}],
    )
    # Check the length first: all() is True for an empty iterable, so without
    # this an empty result would pass the type check vacuously.
    assert len(result) == batch_size
    assert all(isinstance(item, float) for item in result)
def test_progress_full() -> None:
    """A ``progress`` of 1.0 is expected to yield exactly one reward of 1.0."""
    result = reward_progress(_completions(1), metadata=[{"progress": 1.0}])
    # Compare the whole list (as test_progress_none does) so that a result with
    # extra trailing rewards fails instead of being ignored.
    assert result == [1.0]
def test_progress_none() -> None:
    """A ``progress`` of 0.0 is expected to yield a reward of 0.0."""
    episode_metadata = [{"progress": 0.0}]
    rewards = reward_progress(_completions(1), metadata=episode_metadata)
    assert rewards == [0.0]
def test_progress_partial() -> None:
    """A ``cumulative_progress`` of 0.4 is expected to score strictly inside (0, 1)."""
    episode_metadata = [{"cumulative_progress": 0.4}]
    rewards = reward_progress(_completions(1), metadata=episode_metadata)
    score = rewards[0]
    assert score > 0.0 and score < 1.0
def test_progress_normalized() -> None:
    """Every progress reward is expected to land in the closed range [0, 1].

    The batch covers a value below the range, one inside it, one above it,
    and an entry with no progress key at all.
    """
    metadata = [
        {"progress": -1.0},
        {"progress": 0.2},
        {"progress": 2.0},
        {},
    ]
    result = reward_progress(_completions(len(metadata)), metadata=metadata)
    # Check the length first: all() is True for an empty iterable, so without
    # this an empty result would pass the range check vacuously.
    assert len(result) == len(metadata)
    assert all(0.0 <= item <= 1.0 for item in result)
def test_progress_batch() -> None:
    """In-range progress values are expected to pass through unchanged, in order."""
    progress_values = (0.0, 0.5, 1.0)
    episode_metadata = [{"progress": value} for value in progress_values]

    rewards = reward_progress(_completions(3), metadata=episode_metadata)

    assert rewards == [0.0, 0.5, 1.0]
def test_progress_trl_compatible() -> None:
    """Check the return shape: one ``float`` reward per completion.

    The test name ties this to TRL compatibility; what is verified here is
    that the result has one entry per completion and every entry is a float.
    """
    batch_size = 2
    result = reward_progress(
        _completions(batch_size),
        metadata=[{}, {"progress": 0.1}],
    )
    # Check the length first: all() is True for an empty iterable, so without
    # this an empty result would pass the type check vacuously.
    assert len(result) == batch_size
    assert all(isinstance(item, float) for item in result)
def test_operational_good_episode() -> None:
    """Two successful, non-repeating steps are expected to net a positive reward."""
    steps = [
        {"exec_ok": True, "new_info": True, "repeat": False},
        {"exec_ok": True, "new_info": False, "repeat": False},
    ]
    rewards = reward_operational(
        _completions(1),
        metadata=[{"operational_signals": steps}],
    )
    assert rewards[0] > 0.0
def test_operational_all_errors() -> None:
    """An episode where every step fails to execute must not earn a positive reward."""
    failed_step = {"exec_ok": False, "new_info": False, "repeat": False}
    # Copy the template so the two steps are distinct dict objects.
    steps = [dict(failed_step), dict(failed_step)]
    rewards = reward_operational(
        _completions(1),
        metadata=[{"operational_signals": steps}],
    )
    assert rewards[0] <= 0.0
def test_operational_repeat_penalty() -> None:
    """Repeated steps are expected to score strictly lower than fresh ones.

    Both episodes have two successful steps with no new information; they
    differ only in the ``repeat`` flag.
    """

    def _episode_reward(repeat: bool):
        # Two otherwise identical steps; only the repeat flag varies.
        steps = [
            {"exec_ok": True, "new_info": False, "repeat": repeat}
            for _ in range(2)
        ]
        rewards = reward_operational(
            _completions(1),
            metadata=[{"operational_signals": steps}],
        )
        return rewards[0]

    baseline = _episode_reward(False)
    penalized = _episode_reward(True)
    assert penalized < baseline
def test_operational_mixed_signals() -> None:
    """A success, a failure and a repeated success together should score in (0, 4)."""
    informative_step = {"exec_ok": True, "new_info": True, "repeat": False}
    failed_step = {"exec_ok": False, "new_info": False, "repeat": False}
    repeated_step = {"exec_ok": True, "new_info": False, "repeat": True}

    rewards = reward_operational(
        _completions(1),
        metadata=[
            {"operational_signals": [informative_step, failed_step, repeated_step]}
        ],
    )

    score = rewards[0]
    assert score > 0.0 and score < 4.0
def test_operational_single_step() -> None:
    """A one-step episode is expected to produce a ``float`` reward."""
    only_step = {"exec_ok": True, "new_info": False, "repeat": False}
    rewards = reward_operational(
        _completions(1),
        metadata=[{"operational_signals": [only_step]}],
    )
    first_reward = rewards[0]
    assert isinstance(first_reward, float)
def test_operational_batch() -> None:
    """Score a batch mixing ``operational`` values and ``operational_signals``.

    Expects the two ``operational`` values to come back as given and the
    single successful, informative step to score 2.0.
    """
    signal_episode = {
        "operational_signals": [
            {"exec_ok": True, "new_info": True, "repeat": False},
        ]
    }
    episode_metadata = [
        {"operational": 1.0},
        {"operational": -1.5},
        signal_episode,
    ]

    rewards = reward_operational(_completions(3), metadata=episode_metadata)

    # Length is checked separately so a size mismatch is reported on its own.
    assert len(rewards) == 3
    assert rewards == [1.0, -1.5, 2.0]
def test_operational_trl_compatible() -> None:
    """Check the return shape: one ``float`` reward per completion.

    The test name ties this to TRL compatibility; what is verified here is
    that the result has one entry per completion and every entry is a float.
    """
    batch_size = 2
    result = reward_operational(
        _completions(batch_size),
        metadata=[{}, {"operational": 0.5}],
    )
    # Check the length first: all() is True for an empty iterable, so without
    # this an empty result would pass the type check vacuously.
    assert len(result) == batch_size
    assert all(isinstance(item, float) for item in result)