| |
| """Verify SFT and GRPO see identical prompt formats. |
| |
| Renders the same question through both pipelines and compares the |
| tokenized output. Run on Colab or any env with transformers installed: |
| |
| python scripts/test_sft_grpo_alignment.py |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import sys |
| from pathlib import Path |
|
|
# Make the repository root importable so `sql_env` and `scripts` resolve
# when this file is executed directly (python scripts/test_sft_grpo_alignment.py).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from transformers import AutoTokenizer |
|
|
| from sql_env.training.trl_adapter import get_tool_definitions |
|
|
| from scripts.generate_sft_data import get_system_prompt |
|
|
|
|
def render_sft_prompt(
    tokenizer,
    messages: list[dict],
    tools: list[dict],
) -> str:
    """Render a prompt the way the SFT pipeline sees it.

    Applies the tokenizer's chat template to *messages* with the tool
    schemas attached, returning the untokenized text ready for generation.
    """
    render_kwargs = {
        "tools": tools,
        "tokenize": False,
        "add_generation_prompt": True,
    }
    return tokenizer.apply_chat_template(messages, **render_kwargs)
|
|
|
|
def render_grpo_prompt(
    tokenizer,
    messages: list[dict],
    tools: list[dict],
) -> str:
    """Render a prompt the way the GRPO pipeline sees it.

    Intentionally makes the exact same chat-template call as the SFT
    renderer — the alignment test compares the two outputs verbatim.
    """
    rendered: str = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
|
|
|
|
def test_tool_definitions_match_class():
    """Verify get_tool_definitions() extracts all SQLEnvTRL methods.

    Returns the tool list so later tests can reuse it.
    """
    tools = get_tool_definitions()
    actual_names = {entry["function"]["name"] for entry in tools}

    expected = {"describe", "sample", "query", "answer"}
    assert actual_names == expected, (
        f"Tool mismatch: got {actual_names}, expected {expected}"
    )

    # Each tool schema must be fully formed: a name, a description, and a
    # parameter spec with at least one property and one required field.
    for entry in tools:
        func = entry["function"]
        for key in ("name", "description", "parameters"):
            assert key in func
        params = func["parameters"]
        assert len(params["properties"]) > 0, f"{func['name']} has no parameters"
        assert len(params["required"]) > 0, f"{func['name']} has no required params"

    print("[PASS] Tool definitions match SQLEnvTRL methods")
    return tools
|
|
|
|
def test_prompt_parity(tokenizer, tools):
    """Verify SFT and GRPO render identical prompts.

    Renders one representative question through both pipeline helpers
    and compares the raw template output. Returns the rendered prompt
    for downstream checks.
    """
    system_prompt = get_system_prompt(enable_thinking=False)
    # Each fragment ends with a space so implicit string concatenation
    # produces readable sentences (the original dropped the separator
    # after "horsepower?", yielding "horsepower?Tables:").
    question = (
        "How many cars have a larger accelerate than the car with "
        "the largest horsepower? "
        "Tables: car_makers, car_names, cars_data, continents, "
        "countries, model_list. "
        "Use describe, sample, query, and answer tools."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]

    sft_rendered = render_sft_prompt(tokenizer, messages, tools)
    grpo_rendered = render_grpo_prompt(tokenizer, messages, tools)

    assert sft_rendered == grpo_rendered, (
        "SFT and GRPO prompts differ!\n"
        f"SFT length: {len(sft_rendered)}\n"
        f"GRPO length: {len(grpo_rendered)}"
    )
    print("[PASS] SFT and GRPO prompts are identical")
    return sft_rendered
|
|
|
|
| def test_tools_in_rendered_prompt(rendered: str, tools: list[dict]): |
| """Verify the rendered prompt contains tool definitions.""" |
| assert "<tools>" in rendered, "No <tools> block in rendered prompt" |
| assert "</tools>" in rendered, "No </tools> block in rendered prompt" |
|
|
| for tool in tools: |
| name = tool["function"]["name"] |
| assert f'"name": "{name}"' in rendered, ( |
| f"Tool '{name}' not found in rendered prompt" |
| ) |
|
|
| print("[PASS] All tool definitions present in rendered prompt") |
|
|
|
|
def test_sft_data_has_tools(tools: list[dict]):
    """Verify generated SFT data includes tool definitions.

    Skips (with a message) when the data file has not been generated;
    warns when some or all trajectories lack a "tools" key; otherwise
    asserts the first trajectory's tool names match the expected set.
    """
    sft_path = PROJECT_ROOT / "data" / "sft" / "sft_trajectories.json"
    if not sft_path.exists():
        print("[SKIP] SFT data not generated yet")
        return

    # JSON is UTF-8 per spec; be explicit rather than relying on the
    # platform default encoding (avoids mojibake on e.g. Windows cp1252).
    data = json.loads(sft_path.read_text(encoding="utf-8"))

    has_tools = sum(1 for row in data if "tools" in row)
    total = len(data)

    if has_tools == 0:
        print(
            f"[WARN] SFT data has NO tool definitions ({total} "
            "trajectories). Regenerate with: "
            "python scripts/generate_sft_data.py"
        )
    elif has_tools < total:
        print(f"[WARN] Only {has_tools}/{total} trajectories have tools")
    else:
        # Spot-check the first row: its tool names must exactly match the
        # tool set extracted from SQLEnvTRL.
        first_tools = data[0]["tools"]
        first_names = {t["function"]["name"] for t in first_tools}
        expected_names = {t["function"]["name"] for t in tools}
        assert first_names == expected_names, (
            f"SFT data tools {first_names} != expected {expected_names}"
        )
        print(f"[PASS] All {total} SFT trajectories have matching tools")
|
|
|
|
| def test_sft_tool_call_format(tokenizer, tools: list[dict]): |
| """Verify SFT tool_calls render correctly through chat template.""" |
| messages = [ |
| {"role": "system", "content": "You are a SQL assistant."}, |
| {"role": "user", "content": "How many rows in employees?"}, |
| { |
| "role": "assistant", |
| "tool_calls": [ |
| { |
| "type": "function", |
| "function": { |
| "name": "describe", |
| "arguments": {"table_name": "employees"}, |
| }, |
| } |
| ], |
| }, |
| {"role": "tool", "content": "Table 'employees' columns:\n- id"}, |
| { |
| "role": "assistant", |
| "tool_calls": [ |
| { |
| "type": "function", |
| "function": { |
| "name": "query", |
| "arguments": { |
| "sql": "SELECT COUNT(*) FROM employees", |
| }, |
| }, |
| } |
| ], |
| }, |
| {"role": "tool", "content": "1. 42"}, |
| { |
| "role": "assistant", |
| "tool_calls": [ |
| { |
| "type": "function", |
| "function": { |
| "name": "answer", |
| "arguments": {"value": "42"}, |
| }, |
| } |
| ], |
| }, |
| {"role": "tool", "content": "Answer submitted: correct."}, |
| ] |
|
|
| rendered = tokenizer.apply_chat_template( |
| messages, |
| tools=tools, |
| tokenize=False, |
| ) |
|
|
| |
| tool_call_count = rendered.count("<tool_call>") |
| assert tool_call_count == 3, f"Expected 3 tool_calls, got {tool_call_count}" |
|
|
| |
| assert '"name": "describe"' in rendered |
| assert '"name": "query"' in rendered |
| assert '"name": "answer"' in rendered |
|
|
| |
| assert "SELECT COUNT" in rendered |
|
|
| print("[PASS] Multi-turn tool_calls render correctly with tools") |
|
|
|
|
def main():
    """Run every alignment check in order against the Qwen3 tokenizer."""
    model = "Qwen/Qwen3-0.6B"
    print(f"Loading tokenizer: {model}")
    tok = AutoTokenizer.from_pretrained(model)

    print("\n--- Tool Definition Tests ---")
    tool_defs = test_tool_definitions_match_class()

    print("\n--- Prompt Parity Tests ---")
    prompt = test_prompt_parity(tok, tool_defs)

    print("\n--- Tool Presence Tests ---")
    test_tools_in_rendered_prompt(prompt, tool_defs)

    print("\n--- SFT Data Tests ---")
    test_sft_data_has_tools(tool_defs)

    print("\n--- Multi-Turn Rendering Tests ---")
    test_sft_tool_call_format(tok, tool_defs)

    print("\n--- Rendered Prompt Preview ---")
    # Only the head of the prompt; the full render is long.
    print(prompt[:600])
    print("...")

    print("\n=== ALL TESTS PASSED ===")
|
|
|
|
# Script entry point: run the full alignment test suite.
if __name__ == "__main__":
    main()
|
|