# sql_env/scripts/test_sft_grpo_alignment.py
# (hjerpe — uploaded via huggingface_hub, revision 9e64e71, verified)
#!/usr/bin/env python3
"""Verify SFT and GRPO see identical prompt formats.
Renders the same question through both pipelines and compares the
tokenized output. Run on Colab or any env with transformers installed:
python scripts/test_sft_grpo_alignment.py
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from transformers import AutoTokenizer # noqa: E402
from sql_env.training.trl_adapter import get_tool_definitions # noqa: E402
from scripts.generate_sft_data import get_system_prompt # noqa: E402
def render_sft_prompt(
    tokenizer,
    messages: list[dict],
    tools: list[dict],
) -> str:
    """Render *messages* exactly as the SFT pipeline does.

    Applies the tokenizer's chat template with the tool schema attached
    and a generation prompt appended, returning the untokenized string.
    """
    rendered = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
def render_grpo_prompt(
    tokenizer,
    messages: list[dict],
    tools: list[dict],
) -> str:
    """Render *messages* the way GRPO does (same template, same tools).

    Intentionally mirrors the SFT rendering call so the two pipelines
    can be compared byte-for-byte.
    """
    rendered = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
def test_tool_definitions_match_class():
    """Verify get_tool_definitions() extracts all SQLEnvTRL methods.

    Returns the tool list so downstream checks can reuse it.
    """
    tools = get_tool_definitions()
    expected = {"describe", "sample", "query", "answer"}
    tool_names = {entry["function"]["name"] for entry in tools}
    assert tool_names == expected, (
        f"Tool mismatch: got {tool_names}, expected {expected}"
    )
    # Every tool must carry a complete JSON-schema-style definition
    # with at least one property and at least one required field.
    for entry in tools:
        func = entry["function"]
        for key in ("name", "description", "parameters"):
            assert key in func
        props = func["parameters"]["properties"]
        assert len(props) > 0, f"{func['name']} has no parameters"
        required = func["parameters"]["required"]
        assert len(required) > 0, f"{func['name']} has no required params"
    print("[PASS] Tool definitions match SQLEnvTRL methods")
    return tools
def test_prompt_parity(tokenizer, tools):
    """Verify SFT and GRPO render identical prompts.

    Builds one representative question, renders it through both
    pipelines, and asserts byte equality of the two prompt strings.

    Args:
        tokenizer: Chat-template-capable tokenizer.
        tools: OpenAI-style tool definition dicts.

    Returns:
        The rendered SFT prompt (reused by later presence checks).
    """
    system_prompt = get_system_prompt(enable_thinking=False)
    # NOTE: adjacent string literals concatenate with no separator; the
    # original was missing the space after "horsepower?", producing
    # "horsepower?Tables:". Fixed by ending the first literal with " ".
    question = (
        "How many cars have a larger accelerate than the car with "
        "the largest horsepower? "
        "Tables: car_makers, car_names, cars_data, continents, "
        "countries, model_list. "
        "Use describe, sample, query, and answer tools."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    sft_rendered = render_sft_prompt(tokenizer, messages, tools)
    grpo_rendered = render_grpo_prompt(tokenizer, messages, tools)
    assert sft_rendered == grpo_rendered, (
        "SFT and GRPO prompts differ!\n"
        f"SFT length: {len(sft_rendered)}\n"
        f"GRPO length: {len(grpo_rendered)}"
    )
    print("[PASS] SFT and GRPO prompts are identical")
    return sft_rendered
def test_tools_in_rendered_prompt(rendered: str, tools: list[dict]):
"""Verify the rendered prompt contains tool definitions."""
assert "<tools>" in rendered, "No <tools> block in rendered prompt"
assert "</tools>" in rendered, "No </tools> block in rendered prompt"
for tool in tools:
name = tool["function"]["name"]
assert f'"name": "{name}"' in rendered, (
f"Tool '{name}' not found in rendered prompt"
)
print("[PASS] All tool definitions present in rendered prompt")
def test_sft_data_has_tools(tools: list[dict]):
    """Verify SFT data includes tool definitions.

    Skips (with a message) when the trajectory file has not been
    generated; warns when some or all rows lack a "tools" field;
    otherwise asserts the first row's tool names match *tools*.
    """
    sft_path = PROJECT_ROOT / "data" / "sft" / "sft_trajectories.json"
    if not sft_path.exists():
        print("[SKIP] SFT data not generated yet")
        return
    with open(sft_path) as f:
        data = json.load(f)
    total = len(data)
    # Booleans sum as 0/1, counting rows that carry tool definitions.
    has_tools = sum("tools" in row for row in data)
    if has_tools == 0:
        print(
            f"[WARN] SFT data has NO tool definitions ({total} "
            "trajectories). Regenerate with: "
            "python scripts/generate_sft_data.py"
        )
    elif has_tools < total:
        print(f"[WARN] Only {has_tools}/{total} trajectories have tools")
    else:
        # Spot-check the first trajectory's tool names against the
        # expected set extracted from SQLEnvTRL.
        first_names = {t["function"]["name"] for t in data[0]["tools"]}
        expected_names = {t["function"]["name"] for t in tools}
        assert first_names == expected_names, (
            f"SFT data tools {first_names} != expected {expected_names}"
        )
        print(f"[PASS] All {total} SFT trajectories have matching tools")
def test_sft_tool_call_format(tokenizer, tools: list[dict]):
"""Verify SFT tool_calls render correctly through chat template."""
messages = [
{"role": "system", "content": "You are a SQL assistant."},
{"role": "user", "content": "How many rows in employees?"},
{
"role": "assistant",
"tool_calls": [
{
"type": "function",
"function": {
"name": "describe",
"arguments": {"table_name": "employees"},
},
}
],
},
{"role": "tool", "content": "Table 'employees' columns:\n- id"},
{
"role": "assistant",
"tool_calls": [
{
"type": "function",
"function": {
"name": "query",
"arguments": {
"sql": "SELECT COUNT(*) FROM employees",
},
},
}
],
},
{"role": "tool", "content": "1. 42"},
{
"role": "assistant",
"tool_calls": [
{
"type": "function",
"function": {
"name": "answer",
"arguments": {"value": "42"},
},
}
],
},
{"role": "tool", "content": "Answer submitted: correct."},
]
rendered = tokenizer.apply_chat_template(
messages,
tools=tools,
tokenize=False,
)
# Should contain tool_call tags for each assistant turn
tool_call_count = rendered.count("<tool_call>")
assert tool_call_count == 3, f"Expected 3 tool_calls, got {tool_call_count}"
# Each tool call should have the function name
assert '"name": "describe"' in rendered
assert '"name": "query"' in rendered
assert '"name": "answer"' in rendered
# SQL should be present (not null)
assert "SELECT COUNT" in rendered
print("[PASS] Multi-turn tool_calls render correctly with tools")
def main():
    """Run every alignment check in sequence and print a prompt preview."""
    model_name = "Qwen/Qwen3-0.6B"
    print(f"Loading tokenizer: {model_name}")
    tok = AutoTokenizer.from_pretrained(model_name)

    print("\n--- Tool Definition Tests ---")
    tool_defs = test_tool_definitions_match_class()

    print("\n--- Prompt Parity Tests ---")
    prompt_text = test_prompt_parity(tok, tool_defs)

    print("\n--- Tool Presence Tests ---")
    test_tools_in_rendered_prompt(prompt_text, tool_defs)

    print("\n--- SFT Data Tests ---")
    test_sft_data_has_tools(tool_defs)

    print("\n--- Multi-Turn Rendering Tests ---")
    test_sft_tool_call_format(tok, tool_defs)

    print("\n--- Rendered Prompt Preview ---")
    # Only the first 600 characters — full rendered prompts are long.
    print(prompt_text[:600])
    print("...")
    print("\n=== ALL TESTS PASSED ===")


if __name__ == "__main__":
    main()