"""Integration tests for the GRPO training pipeline flow.""" from __future__ import annotations from sql_env.models import SQLObservation from sql_env.training import rollout as rollout_module from sql_env.training.config import GRPOConfig from sql_env.training.rewards import ( reward_correctness, reward_operational, reward_progress, ) from sql_env.training.rollout import rollout_func class _Tokenizer: def apply_chat_template( self, messages: list[dict[str, str]], tokenize: bool = False, add_generation_prompt: bool = True, ) -> str: del messages del tokenize del add_generation_prompt return "prompt" class _Model: def __init__(self) -> None: self.calls = 0 def generate(self, prompt: str, max_new_tokens: int) -> str: del prompt del max_new_tokens self.calls += 1 if self.calls == 1: return "hello world random text" return "ANSWER: 42" class _Environment: def __init__(self, step_budget: int) -> None: self.step_budget = step_budget self.step_count = 0 self.state = type("State", (), {"episode_id": "ep-integration"})() def reset(self, *, seed: int | None = None) -> SQLObservation: del seed self.step_count = 0 return self._observation(done=False, result="") def step(self, action) -> SQLObservation: self.step_count += 1 if ( action.action_type == "QUERY" and action.argument == "hello world random text" ): return self._observation(done=False, result="", error="unparseable action") if action.action_type == "ANSWER": return self._observation( done=True, result="Answer submitted: correct.", reward=1.0 ) return self._observation(done=False, result="ok", reward=0.1) def _observation( self, *, done: bool, result: str, error: str = "", reward: float | None = 0.0, ) -> SQLObservation: return SQLObservation( question="How many rows?", schema_info="Available tables:\n- t", result=result, error=error, step_count=self.step_count, budget_remaining=max(0, self.step_budget - self.step_count), action_history=[], done=done, reward=reward, ) def test_training_pipeline_flow_with_reward_functions(monkeypatch) -> None: """Rollout output can be consumed by all reward callables.""" config = GRPOConfig( questions_path="data/questions/questions_train.json", db_dir="data/databases", output_dir="outputs/grpo_test", step_budget=3, ) tokenizer = _Tokenizer() model = _Model() fake_env = _Environment(step_budget=3) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) rollouts = rollout_func(["Count rows"], model, tokenizer, config) assert len(rollouts) == 1 metadata = [item["metadata"] for item in rollouts] completions = [ [{"role": "assistant", "content": item["content"]}] for item in rollouts ] assert reward_correctness(completions, metadata=metadata) == [1.0] progress = reward_progress(completions, metadata=metadata) operational = reward_operational(completions, metadata=metadata) assert len(progress) == 1 assert 0.0 <= progress[0] <= 1.0 assert len(operational) == 1 def test_unparseable_action_recovers_and_episode_continues(monkeypatch) -> None: """Unparseable model output falls back to QUERY and does not abort episode.""" config = GRPOConfig( questions_path="data/questions/questions_train.json", db_dir="data/databases", output_dir="outputs/grpo_test", step_budget=3, ) tokenizer = _Tokenizer() model = _Model() fake_env = _Environment(step_budget=3) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) rollout = rollout_func(["Count rows"], model, tokenizer, config)[0] assert rollout["metadata"]["step_count"] >= 2 assert rollout["metadata"]["done"] is True