# sql_env/tests/integration/test_training_pipeline.py
"""Integration tests for the GRPO training pipeline flow."""
from __future__ import annotations
from sql_env.models import SQLObservation
from sql_env.training import rollout as rollout_module
from sql_env.training.config import GRPOConfig
from sql_env.training.rewards import (
reward_correctness,
reward_operational,
reward_progress,
)
from sql_env.training.rollout import rollout_func
class _Tokenizer:
def apply_chat_template(
self,
messages: list[dict[str, str]],
tokenize: bool = False,
add_generation_prompt: bool = True,
) -> str:
del messages
del tokenize
del add_generation_prompt
return "prompt"
class _Model:
def __init__(self) -> None:
self.calls = 0
def generate(self, prompt: str, max_new_tokens: int) -> str:
del prompt
del max_new_tokens
self.calls += 1
if self.calls == 1:
return "hello world random text"
return "ANSWER: 42"
class _Environment:
    """Deterministic environment double with a fixed step budget.

    step() semantics: the literal fallback query produces an error observation,
    an ANSWER action ends the episode with full reward, anything else yields a
    small positive reward and keeps the episode open.
    """

    def __init__(self, step_budget: int) -> None:
        self.step_budget = step_budget
        self.step_count = 0
        # Anonymous object carrying only the episode id expected by callers.
        self.state = type("State", (), {"episode_id": "ep-integration"})()

    def reset(self, *, seed: int | None = None) -> SQLObservation:
        _ = seed  # seeding is meaningless for a scripted environment
        self.step_count = 0
        return self._observation(done=False, result="")

    def step(self, action) -> SQLObservation:
        self.step_count += 1
        is_fallback_query = (
            action.action_type == "QUERY"
            and action.argument == "hello world random text"
        )
        if is_fallback_query:
            return self._observation(done=False, result="", error="unparseable action")
        if action.action_type == "ANSWER":
            return self._observation(
                done=True, result="Answer submitted: correct.", reward=1.0
            )
        return self._observation(done=False, result="ok", reward=0.1)

    def _observation(
        self,
        *,
        done: bool,
        result: str,
        error: str = "",
        reward: float | None = 0.0,
    ) -> SQLObservation:
        # Budget never goes negative even if stepping past exhaustion.
        remaining = self.step_budget - self.step_count
        return SQLObservation(
            question="How many rows?",
            schema_info="Available tables:\n- t",
            result=result,
            error=error,
            step_count=self.step_count,
            budget_remaining=remaining if remaining > 0 else 0,
            action_history=[],
            done=done,
            reward=reward,
        )
def test_training_pipeline_flow_with_reward_functions(monkeypatch) -> None:
    """Rollout output can be consumed by all reward callables."""
    cfg = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=3,
    )
    stub_env = _Environment(step_budget=3)
    # Swap the real environment factory for the scripted double.
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: stub_env)

    rollouts = rollout_func(["Count rows"], _Model(), _Tokenizer(), cfg)

    assert len(rollouts) == 1
    metadata = [entry["metadata"] for entry in rollouts]
    completions = [
        [{"role": "assistant", "content": entry["content"]}] for entry in rollouts
    ]
    # The scripted episode ends with a correct ANSWER, so correctness is 1.0.
    assert reward_correctness(completions, metadata=metadata) == [1.0]
    progress = reward_progress(completions, metadata=metadata)
    operational = reward_operational(completions, metadata=metadata)
    assert len(progress) == 1
    assert 0.0 <= progress[0] <= 1.0
    assert len(operational) == 1
def test_unparseable_action_recovers_and_episode_continues(monkeypatch) -> None:
    """Unparseable model output falls back to QUERY and does not abort episode."""
    cfg = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=3,
    )
    stub_env = _Environment(step_budget=3)
    # Swap the real environment factory for the scripted double.
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: stub_env)

    rollout = rollout_func(["Count rows"], _Model(), _Tokenizer(), cfg)[0]

    meta = rollout["metadata"]
    # First generation is gibberish, so at least two steps must have occurred
    # before the episode finished.
    assert meta["step_count"] >= 2
    assert meta["done"] is True