| """Unit tests for training reward callables.""" |
|
|
| from sql_env.training.rewards import ( |
| reward_correctness, |
| reward_operational, |
| reward_progress, |
| ) |
|
|
|
|
| def _completions(size: int) -> list[list[dict[str, str]]]: |
| return [[{"role": "assistant", "content": "QUERY: SELECT 1"}] for _ in range(size)] |
|
|
|
|
def test_correctness_correct_answer() -> None:
    """A `correct: True` metadata flag yields the full reward of 1.0."""
    rewards = reward_correctness(_completions(1), metadata=[{"correct": True}])
    assert rewards == [1.0]
|
|
|
|
def test_correctness_wrong_answer() -> None:
    """A `correct: False` metadata flag yields a zero reward."""
    rewards = reward_correctness(_completions(1), metadata=[{"correct": False}])
    assert rewards == [0.0]
|
|
|
|
def test_correctness_no_answer() -> None:
    """Empty metadata (no correctness key at all) defaults to a zero reward."""
    rewards = reward_correctness(_completions(1), metadata=[{}])
    assert rewards == [0.0]
|
|
|
|
def test_correctness_batch() -> None:
    """Per-sample rewards honour both the 'answer_correct' and 'correct' keys."""
    batch_metadata = [
        {"answer_correct": True},
        {"answer_correct": False},
        {"correct": True},
        {"correct": False},
    ]
    rewards = reward_correctness(_completions(4), metadata=batch_metadata)
    assert rewards == [1.0, 0.0, 1.0, 0.0]
|
|
|
|
def test_correctness_empty_batch() -> None:
    """An empty completions batch produces an empty reward list."""
    assert reward_correctness([]) == []
|
|
|
|
def test_correctness_trl_compatible() -> None:
    """Every returned reward is a plain float (the type TRL expects)."""
    rewards = reward_correctness(_completions(2), metadata=[{"correct": True}, {}])
    for score in rewards:
        assert isinstance(score, float)
|
|
|
|
def test_progress_full() -> None:
    """Full progress (1.0) maps to a reward of exactly 1.0."""
    rewards = reward_progress(_completions(1), metadata=[{"progress": 1.0}])
    assert rewards[0] == 1.0
|
|
|
|
def test_progress_none() -> None:
    """Zero progress maps to a zero reward."""
    rewards = reward_progress(_completions(1), metadata=[{"progress": 0.0}])
    assert rewards == [0.0]
|
|
|
|
def test_progress_partial() -> None:
    """A 'cumulative_progress' key produces a strictly intermediate reward."""
    score = reward_progress(_completions(1), metadata=[{"cumulative_progress": 0.4}])[0]
    assert 0.0 < score < 1.0
|
|
|
|
def test_progress_normalized() -> None:
    """Out-of-range and missing progress values stay within [0, 1]."""
    edge_cases = [
        {"progress": -1.0},
        {"progress": 0.2},
        {"progress": 2.0},
        {},
    ]
    rewards = reward_progress(_completions(4), metadata=edge_cases)
    for score in rewards:
        assert 0.0 <= score <= 1.0
|
|
|
|
def test_progress_batch() -> None:
    """In-range progress values pass through unchanged, in order."""
    batch_metadata = [{"progress": 0.0}, {"progress": 0.5}, {"progress": 1.0}]
    rewards = reward_progress(_completions(3), metadata=batch_metadata)
    assert rewards == [0.0, 0.5, 1.0]
|
|
|
|
def test_progress_trl_compatible() -> None:
    """Every returned reward is a plain float (the type TRL expects)."""
    rewards = reward_progress(_completions(2), metadata=[{}, {"progress": 0.1}])
    for score in rewards:
        assert isinstance(score, float)
|
|
|
|
def test_operational_good_episode() -> None:
    """Successful, informative steps earn a positive operational reward."""
    signals = [
        {"exec_ok": True, "new_info": True, "repeat": False},
        {"exec_ok": True, "new_info": False, "repeat": False},
    ]
    rewards = reward_operational(
        _completions(1), metadata=[{"operational_signals": signals}]
    )
    assert rewards[0] > 0.0
|
|
|
|
def test_operational_all_errors() -> None:
    """An episode of only failed executions scores at or below zero."""
    signals = [
        {"exec_ok": False, "new_info": False, "repeat": False},
        {"exec_ok": False, "new_info": False, "repeat": False},
    ]
    rewards = reward_operational(
        _completions(1), metadata=[{"operational_signals": signals}]
    )
    assert rewards[0] <= 0.0
|
|
|
|
def test_operational_repeat_penalty() -> None:
    """Repeated queries score strictly lower than otherwise-identical ones."""

    def _episode_score(repeat: bool) -> float:
        # Two successful, non-informative steps; only the repeat flag varies.
        signals = [
            {"exec_ok": True, "new_info": False, "repeat": repeat}
            for _ in range(2)
        ]
        return reward_operational(
            _completions(1), metadata=[{"operational_signals": signals}]
        )[0]

    without_repeats = _episode_score(repeat=False)
    with_repeats = _episode_score(repeat=True)
    assert with_repeats < without_repeats
|
|
|
|
def test_operational_mixed_signals() -> None:
    """A mix of good, failed, and repeated steps lands strictly in (0, 4)."""
    signals = [
        {"exec_ok": True, "new_info": True, "repeat": False},
        {"exec_ok": False, "new_info": False, "repeat": False},
        {"exec_ok": True, "new_info": False, "repeat": True},
    ]
    rewards = reward_operational(
        _completions(1), metadata=[{"operational_signals": signals}]
    )
    assert 0.0 < rewards[0] < 4.0
|
|
|
|
def test_operational_single_step() -> None:
    """A one-step episode still produces a plain float reward."""
    signals = [{"exec_ok": True, "new_info": False, "repeat": False}]
    rewards = reward_operational(
        _completions(1), metadata=[{"operational_signals": signals}]
    )
    assert isinstance(rewards[0], float)
|
|
|
|
def test_operational_batch() -> None:
    """Precomputed 'operational' scalars and raw signals mix within one batch."""
    batch_metadata = [
        {"operational": 1.0},
        {"operational": -1.5},
        {
            "operational_signals": [
                {"exec_ok": True, "new_info": True, "repeat": False},
            ]
        },
    ]
    rewards = reward_operational(_completions(3), metadata=batch_metadata)
    assert len(rewards) == 3
    assert rewards == [1.0, -1.5, 2.0]
|
|
|
|
def test_operational_trl_compatible() -> None:
    """Every returned reward is a plain float (the type TRL expects)."""
    rewards = reward_operational(_completions(2), metadata=[{}, {"operational": 0.5}])
    for score in rewards:
        assert isinstance(score, float)
|
|