"""Unit tests for training reward callables.""" from sql_env.training.rewards import ( reward_correctness, reward_operational, reward_progress, ) def _completions(size: int) -> list[list[dict[str, str]]]: return [[{"role": "assistant", "content": "QUERY: SELECT 1"}] for _ in range(size)] def test_correctness_correct_answer() -> None: result = reward_correctness(_completions(1), metadata=[{"correct": True}]) assert result == [1.0] def test_correctness_wrong_answer() -> None: result = reward_correctness(_completions(1), metadata=[{"correct": False}]) assert result == [0.0] def test_correctness_no_answer() -> None: result = reward_correctness(_completions(1), metadata=[{}]) assert result == [0.0] def test_correctness_batch() -> None: result = reward_correctness( _completions(4), metadata=[ {"answer_correct": True}, {"answer_correct": False}, {"correct": True}, {"correct": False}, ], ) assert result == [1.0, 0.0, 1.0, 0.0] def test_correctness_empty_batch() -> None: result = reward_correctness([]) assert result == [] def test_correctness_trl_compatible() -> None: result = reward_correctness(_completions(2), metadata=[{"correct": True}, {}]) assert all(isinstance(item, float) for item in result) def test_progress_full() -> None: result = reward_progress(_completions(1), metadata=[{"progress": 1.0}]) assert result[0] == 1.0 def test_progress_none() -> None: result = reward_progress(_completions(1), metadata=[{"progress": 0.0}]) assert result == [0.0] def test_progress_partial() -> None: result = reward_progress(_completions(1), metadata=[{"cumulative_progress": 0.4}]) assert 0.0 < result[0] < 1.0 def test_progress_normalized() -> None: result = reward_progress( _completions(4), metadata=[ {"progress": -1.0}, {"progress": 0.2}, {"progress": 2.0}, {}, ], ) assert all(0.0 <= item <= 1.0 for item in result) def test_progress_batch() -> None: result = reward_progress( _completions(3), metadata=[{"progress": 0.0}, {"progress": 0.5}, {"progress": 1.0}], ) assert result == [0.0, 0.5, 1.0] def test_progress_trl_compatible() -> None: result = reward_progress(_completions(2), metadata=[{}, {"progress": 0.1}]) assert all(isinstance(item, float) for item in result) def test_operational_good_episode() -> None: result = reward_operational( _completions(1), metadata=[ { "operational_signals": [ {"exec_ok": True, "new_info": True, "repeat": False}, {"exec_ok": True, "new_info": False, "repeat": False}, ] } ], ) assert result[0] > 0.0 def test_operational_all_errors() -> None: result = reward_operational( _completions(1), metadata=[ { "operational_signals": [ {"exec_ok": False, "new_info": False, "repeat": False}, {"exec_ok": False, "new_info": False, "repeat": False}, ] } ], ) assert result[0] <= 0.0 def test_operational_repeat_penalty() -> None: non_repeating = reward_operational( _completions(1), metadata=[ { "operational_signals": [ {"exec_ok": True, "new_info": False, "repeat": False}, {"exec_ok": True, "new_info": False, "repeat": False}, ] } ], ) repeating = reward_operational( _completions(1), metadata=[ { "operational_signals": [ {"exec_ok": True, "new_info": False, "repeat": True}, {"exec_ok": True, "new_info": False, "repeat": True}, ] } ], ) assert repeating[0] < non_repeating[0] def test_operational_mixed_signals() -> None: result = reward_operational( _completions(1), metadata=[ { "operational_signals": [ {"exec_ok": True, "new_info": True, "repeat": False}, {"exec_ok": False, "new_info": False, "repeat": False}, {"exec_ok": True, "new_info": False, "repeat": True}, ] } ], ) assert 0.0 < result[0] < 4.0 def test_operational_single_step() -> None: result = reward_operational( _completions(1), metadata=[ { "operational_signals": [ {"exec_ok": True, "new_info": False, "repeat": False} ] } ], ) assert isinstance(result[0], float) def test_operational_batch() -> None: result = reward_operational( _completions(3), metadata=[ {"operational": 1.0}, {"operational": -1.5}, { "operational_signals": [ {"exec_ok": True, "new_info": True, "repeat": False}, ] }, ], ) assert len(result) == 3 assert result == [1.0, -1.5, 2.0] def test_operational_trl_compatible() -> None: result = reward_operational(_completions(2), metadata=[{}, {"operational": 0.5}]) assert all(isinstance(item, float) for item in result)