sql_env / tests /unit /test_grpo_config.py
hjerpe's picture
Upload folder using huggingface_hub
9e64e71 verified
"""Unit tests for GRPO training configuration."""
import pytest
from sql_env.training import GRPOConfig
def test_grpo_config_custom_values() -> None:
    """Check that explicitly supplied settings can be read back from the config.

    The three path arguments are passed to the constructor but are not
    asserted on here; only the remaining fields are compared.
    """
    config = GRPOConfig(
        questions_path="custom.json",
        db_dir="custom_db/",
        output_dir="custom_out/",
        model_name="gpt2",
        max_new_tokens=128,
        num_train_epochs=3,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        num_generations=2,
        step_budget=5,
        difficulty_filter=["easy"],
        seed=7,
        logging_steps=5,
    )

    # Expected values are written out as fresh literals (not reused from the
    # constructor call) so the comparison is independent of what was passed in.
    expected_by_field = {
        "model_name": "gpt2",
        "max_new_tokens": 128,
        "num_train_epochs": 3,
        "per_device_train_batch_size": 8,
        "gradient_accumulation_steps": 2,
        "learning_rate": 1e-5,
        "num_generations": 2,
        "step_budget": 5,
        "difficulty_filter": ["easy"],
        "seed": 7,
        "logging_steps": 5,
    }
    for field_name, expected_value in expected_by_field.items():
        assert getattr(config, field_name) == expected_value
def test_grpo_config_required_fields() -> None:
    """Check that constructing GRPOConfig with no arguments raises TypeError."""
    expects_type_error = pytest.raises(TypeError)
    with expects_type_error:
        # Deliberately called with nothing; the ignore silences the type
        # checker's complaint about the missing arguments.
        GRPOConfig()  # type: ignore[call-arg]
def test_grpo_config_negative_batch_size() -> None:
    """Check that a batch size of 0 is rejected with a ValueError.

    Despite the test name, the value exercised here is zero rather than a
    negative number. The error message must mention
    ``per_device_train_batch_size``.
    """
    path_kwargs = {
        "questions_path": "q.json",
        "db_dir": "dbs/",
        "output_dir": "out/",
    }
    expects_value_error = pytest.raises(
        ValueError, match="per_device_train_batch_size"
    )
    with expects_value_error:
        GRPOConfig(**path_kwargs, per_device_train_batch_size=0)
def test_grpo_config_negative_learning_rate() -> None:
    """Check that a learning rate of -1.0 is rejected with a ValueError.

    The error message must mention ``learning_rate``.
    """
    path_kwargs = {
        "questions_path": "q.json",
        "db_dir": "dbs/",
        "output_dir": "out/",
    }
    expects_value_error = pytest.raises(ValueError, match="learning_rate")
    with expects_value_error:
        GRPOConfig(**path_kwargs, learning_rate=-1.0)
def test_grpo_config_empty_difficulty_filter() -> None:
    """Check that an empty difficulty filter is rejected with a ValueError.

    The error message must mention ``difficulty_filter``.
    """
    path_kwargs = {
        "questions_path": "q.json",
        "db_dir": "dbs/",
        "output_dir": "out/",
    }
    expects_value_error = pytest.raises(ValueError, match="difficulty_filter")
    with expects_value_error:
        GRPOConfig(**path_kwargs, difficulty_filter=[])