| """Unit tests for GRPO training configuration.""" |
|
|
| import pytest |
|
|
| from sql_env.training import GRPOConfig |
|
|
|
|
def test_grpo_config_custom_values() -> None:
    """Explicitly supplied settings are readable back from the config."""
    # The path arguments are supplied to build the config; they are not
    # among the attributes asserted by this test.
    paths = {
        "questions_path": "custom.json",
        "db_dir": "custom_db/",
        "output_dir": "custom_out/",
    }
    overrides = {
        "model_name": "gpt2",
        "max_new_tokens": 128,
        "num_train_epochs": 3,
        "per_device_train_batch_size": 8,
        "gradient_accumulation_steps": 2,
        "learning_rate": 1e-5,
        "num_generations": 2,
        "step_budget": 5,
        "difficulty_filter": ["easy"],
        "seed": 7,
        "logging_steps": 5,
    }

    config = GRPOConfig(**paths, **overrides)

    # Each override must come back unchanged; the field name is used as
    # the assertion message so a failure identifies the attribute.
    for field_name, expected in overrides.items():
        assert getattr(config, field_name) == expected, field_name
|
|
|
|
def test_grpo_config_required_fields() -> None:
    """Constructing the config with no arguments raises ``TypeError``."""
    with pytest.raises(TypeError):
        GRPOConfig()
|
|
|
|
def test_grpo_config_negative_batch_size() -> None:
    """Non-positive ``per_device_train_batch_size`` values are rejected.

    Both the zero boundary and a strictly negative value are exercised.
    Each must raise ``ValueError`` with a message that mentions the
    field name.
    """
    # Zero alone does not cover the "negative" case the test name
    # promises, so a value below zero is checked as well.
    for bad_batch_size in (0, -1):
        with pytest.raises(ValueError, match="per_device_train_batch_size"):
            GRPOConfig(
                questions_path="q.json",
                db_dir="dbs/",
                output_dir="out/",
                per_device_train_batch_size=bad_batch_size,
            )
|
|
|
|
def test_grpo_config_negative_learning_rate() -> None:
    """A negative ``learning_rate`` raises ``ValueError`` naming the field."""
    invalid_kwargs = {
        "questions_path": "q.json",
        "db_dir": "dbs/",
        "output_dir": "out/",
        "learning_rate": -1.0,
    }

    with pytest.raises(ValueError, match="learning_rate"):
        GRPOConfig(**invalid_kwargs)
|
|
|
|
def test_grpo_config_empty_difficulty_filter() -> None:
    """An empty ``difficulty_filter`` raises ``ValueError`` naming the field."""
    invalid_kwargs = {
        "questions_path": "q.json",
        "db_dir": "dbs/",
        "output_dir": "out/",
        "difficulty_filter": [],
    }

    with pytest.raises(ValueError, match="difficulty_filter"):
        GRPOConfig(**invalid_kwargs)
|
|