| """Integration tests for the GRPO training pipeline flow.""" |
|
|
| from __future__ import annotations |
|
|
| from sql_env.models import SQLObservation |
| from sql_env.training import rollout as rollout_module |
| from sql_env.training.config import GRPOConfig |
| from sql_env.training.rewards import ( |
| reward_correctness, |
| reward_operational, |
| reward_progress, |
| ) |
| from sql_env.training.rollout import rollout_func |
|
|
|
|
class _Tokenizer:
    """Tokenizer stub whose chat template always renders the same string."""

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        tokenize: bool = False,
        add_generation_prompt: bool = True,
    ) -> str:
        """Return the fixed text ``"prompt"`` regardless of the arguments."""
        # The inputs are deliberately unused; drop them in one statement.
        del messages, tokenize, add_generation_prompt
        return "prompt"
|
|
|
|
class _Model:
    """Scripted model stub: free text on the first call, an answer afterwards."""

    def __init__(self) -> None:
        # Number of times ``generate`` has been invoked on this instance.
        self.calls = 0

    def generate(self, prompt: str, max_new_tokens: int) -> str:
        """Return the next scripted completion and count the call.

        The first call yields ``"hello world random text"``; every later call
        yields ``"ANSWER: 42"``. Both arguments are ignored.
        """
        del prompt, max_new_tokens
        self.calls += 1
        return "hello world random text" if self.calls == 1 else "ANSWER: 42"
|
|
|
|
class _Environment:
    """Fake environment that answers rollout actions with canned observations."""

    def __init__(self, step_budget: int) -> None:
        self.step_budget = step_budget
        self.step_count = 0
        # Ad-hoc object exposing a single ``episode_id`` attribute.
        self.state = type("State", (), {"episode_id": "ep-integration"})()

    def reset(self, *, seed: int | None = None) -> SQLObservation:
        """Zero the step counter and return a fresh, not-done observation."""
        del seed
        self.step_count = 0
        return self._observation(done=False, result="")

    def step(self, action) -> SQLObservation:
        """Count one step and return the canned observation for ``action``.

        * A ``QUERY`` carrying ``"hello world random text"`` yields an
          ``"unparseable action"`` error while the episode stays open.
        * Any ``ANSWER`` finishes the episode with reward ``1.0``.
        * Everything else yields result ``"ok"`` with reward ``0.1``.
        """
        self.step_count += 1

        is_garbage_query = (
            action.action_type == "QUERY"
            and action.argument == "hello world random text"
        )
        if is_garbage_query:
            return self._observation(
                done=False, result="", error="unparseable action"
            )

        is_answer = action.action_type == "ANSWER"
        if is_answer:
            return self._observation(
                done=True, result="Answer submitted: correct.", reward=1.0
            )

        return self._observation(done=False, result="ok", reward=0.1)

    def _observation(
        self,
        *,
        done: bool,
        result: str,
        error: str = "",
        reward: float | None = 0.0,
    ) -> SQLObservation:
        """Build an ``SQLObservation`` reflecting the current step counter."""
        # Clamp the remaining budget so it never goes below zero.
        remaining = self.step_budget - self.step_count
        budget_left = remaining if remaining > 0 else 0
        return SQLObservation(
            question="How many rows?",
            schema_info="Available tables:\n- t",
            result=result,
            error=error,
            step_count=self.step_count,
            budget_remaining=budget_left,
            action_history=[],
            done=done,
            reward=reward,
        )
|
|
|
|
def test_training_pipeline_flow_with_reward_functions(monkeypatch) -> None:
    """Check that every reward callable accepts what ``rollout_func`` returns."""
    scripted_env = _Environment(step_budget=3)

    def _use_scripted_env(*_ignored):
        # Stand-in for the real environment builder.
        return scripted_env

    settings = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=3,
    )
    monkeypatch.setattr(rollout_module, "_build_environment", _use_scripted_env)

    rollouts = rollout_func(["Count rows"], _Model(), _Tokenizer(), settings)
    assert len(rollouts) == 1

    # Reshape the rollouts into the (completions, metadata) pair that the
    # reward functions are called with.
    metadata = []
    completions = []
    for item in rollouts:
        metadata.append(item["metadata"])
        completions.append([{"role": "assistant", "content": item["content"]}])

    correctness_scores = reward_correctness(completions, metadata=metadata)
    assert correctness_scores == [1.0]

    progress = reward_progress(completions, metadata=metadata)
    operational = reward_operational(completions, metadata=metadata)

    assert len(progress) == 1
    assert 0.0 <= progress[0] <= 1.0
    assert len(operational) == 1
|
|
|
|
def test_unparseable_action_recovers_and_episode_continues(monkeypatch) -> None:
    """An unparseable first completion must not end the episode.

    The stub model emits free text first and ``ANSWER: 42`` afterwards, so the
    rollout should record at least two steps and finish as done.
    """
    scripted_env = _Environment(step_budget=3)

    def _use_scripted_env(*_ignored):
        # Stand-in for the real environment builder.
        return scripted_env

    settings = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=3,
    )
    monkeypatch.setattr(rollout_module, "_build_environment", _use_scripted_env)

    results = rollout_func(["Count rows"], _Model(), _Tokenizer(), settings)
    episode_meta = results[0]["metadata"]

    assert episode_meta["step_count"] >= 2
    assert episode_meta["done"] is True
|
|