"""Unit tests for GRPO training configuration.""" import pytest from sql_env.training import GRPOConfig def test_grpo_config_custom_values() -> None: config = GRPOConfig( questions_path="custom.json", db_dir="custom_db/", output_dir="custom_out/", model_name="gpt2", max_new_tokens=128, num_train_epochs=3, per_device_train_batch_size=8, gradient_accumulation_steps=2, learning_rate=1e-5, num_generations=2, step_budget=5, difficulty_filter=["easy"], seed=7, logging_steps=5, ) assert config.model_name == "gpt2" assert config.max_new_tokens == 128 assert config.num_train_epochs == 3 assert config.per_device_train_batch_size == 8 assert config.gradient_accumulation_steps == 2 assert config.learning_rate == 1e-5 assert config.num_generations == 2 assert config.step_budget == 5 assert config.difficulty_filter == ["easy"] assert config.seed == 7 assert config.logging_steps == 5 def test_grpo_config_required_fields() -> None: with pytest.raises(TypeError): GRPOConfig() # type: ignore[call-arg] def test_grpo_config_negative_batch_size() -> None: with pytest.raises(ValueError, match="per_device_train_batch_size"): GRPOConfig( questions_path="q.json", db_dir="dbs/", output_dir="out/", per_device_train_batch_size=0, ) def test_grpo_config_negative_learning_rate() -> None: with pytest.raises(ValueError, match="learning_rate"): GRPOConfig( questions_path="q.json", db_dir="dbs/", output_dir="out/", learning_rate=-1.0, ) def test_grpo_config_empty_difficulty_filter() -> None: with pytest.raises(ValueError, match="difficulty_filter"): GRPOConfig( questions_path="q.json", db_dir="dbs/", output_dir="out/", difficulty_filter=[], )