#!/usr/bin/env python3
"""Verify SFT and GRPO see identical prompt formats.

Renders the same question through both pipelines and compares the
rendered (untokenized) chat-template text. Run on Colab or any env with
transformers installed:

    python scripts/test_sft_grpo_alignment.py
"""

from __future__ import annotations

import json
import re
import sys
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Make the project importable when the script is run directly.
sys.path.insert(0, str(PROJECT_ROOT))

from transformers import AutoTokenizer  # noqa: E402

from sql_env.training.trl_adapter import get_tool_definitions  # noqa: E402
from scripts.generate_sft_data import get_system_prompt  # noqa: E402

# Matches an assistant tool call as rendered by the chat template: the
# opening tag followed by a JSON object whose "name" value is a quoted
# string.
# NOTE(review): the Qwen chat template's tool preamble is believed to
# mention the bare <tool_call> tag too (with an unquoted placeholder
# name); anchoring on the quoted name avoids counting it. Confirm
# against the tokenizer's chat template.
_RENDERED_TOOL_CALL = re.compile(r'<tool_call>\s*\{"name": "')


def render_sft_prompt(
    tokenizer,
    messages: list[dict],
    tools: list[dict],
) -> str:
    """Render a prompt the way SFT sees it.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        messages: Chat messages to render.
        tools: Tool definitions forwarded to the chat template.

    Returns:
        The rendered prompt text, with the generation prompt appended.
    """
    return tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
        add_generation_prompt=True,
    )


def render_grpo_prompt(
    tokenizer,
    messages: list[dict],
    tools: list[dict],
) -> str:
    """Render a prompt the way GRPO sees it (same template, same tools).

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        messages: Chat messages to render.
        tools: Tool definitions forwarded to the chat template.

    Returns:
        The rendered prompt text, with the generation prompt appended.
    """
    return tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
        add_generation_prompt=True,
    )


def test_tool_definitions_match_class():
    """Verify get_tool_definitions() extracts all SQLEnvTRL methods.

    Returns:
        The tool definitions, so later checks can reuse them.

    Raises:
        AssertionError: If the tool names differ from the expected set, or
            a tool lacks a name, description, parameters, or required
            parameters.
    """
    tools = get_tool_definitions()
    tool_names = {t["function"]["name"] for t in tools}

    expected = {"describe", "sample", "query", "answer"}
    assert tool_names == expected, (
        f"Tool mismatch: got {tool_names}, expected {expected}"
    )

    # Each tool should have parameters with required fields
    for tool in tools:
        func = tool["function"]
        assert "name" in func
        assert "description" in func
        assert "parameters" in func
        props = func["parameters"]["properties"]
        assert len(props) > 0, f"{func['name']} has no parameters"
        required = func["parameters"]["required"]
        assert len(required) > 0, f"{func['name']} has no required params"

    print("[PASS] Tool definitions match SQLEnvTRL methods")
    return tools


def test_prompt_parity(tokenizer, tools):
    """Verify SFT and GRPO render identical prompts.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        tools: Tool definitions forwarded to both renderers.

    Returns:
        The SFT-rendered prompt, for reuse by later checks.

    Raises:
        AssertionError: If the two rendered prompts differ.
    """
    system_prompt = get_system_prompt(enable_thinking=False)
    # Every fragment ends with a separator: adjacent literals are joined
    # with nothing in between.
    question = (
        "How many cars have a larger accelerate than the car with "
        "the largest horsepower? "
        "Tables: car_makers, car_names, cars_data, continents, "
        "countries, model_list. "
        "Use describe, sample, query, and answer tools."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]

    sft_rendered = render_sft_prompt(tokenizer, messages, tools)
    grpo_rendered = render_grpo_prompt(tokenizer, messages, tools)

    assert sft_rendered == grpo_rendered, (
        "SFT and GRPO prompts differ!\n"
        f"SFT length: {len(sft_rendered)}\n"
        f"GRPO length: {len(grpo_rendered)}"
    )

    print("[PASS] SFT and GRPO prompts are identical")
    return sft_rendered


def test_tools_in_rendered_prompt(rendered: str, tools: list[dict]):
    """Verify the rendered prompt contains tool definitions.

    Args:
        rendered: Prompt text produced by the chat template.
        tools: Tool definitions expected to appear in ``rendered``.

    Raises:
        AssertionError: If the ``<tools>`` block delimiters or any tool
            name is missing from ``rendered``.
    """
    # An empty needle is contained in every string, so the tag literals
    # must be spelled out for these checks to mean anything.
    assert "<tools>" in rendered, "No <tools> block in rendered prompt"
    assert "</tools>" in rendered, "No </tools> block in rendered prompt"

    for tool in tools:
        name = tool["function"]["name"]
        assert f'"name": "{name}"' in rendered, (
            f"Tool '{name}' not found in rendered prompt"
        )

    print("[PASS] All tool definitions present in rendered prompt")


def test_sft_data_has_tools(tools: list[dict]):
    """Verify SFT data includes tool definitions.

    Reads ``data/sft/sft_trajectories.json`` under the project root. The
    check is skipped when the file does not exist, and only warns when
    some or all rows lack a ``"tools"`` key.

    Args:
        tools: Tool definitions the stored trajectories should match.

    Raises:
        AssertionError: If every row has tools but the first row's tool
            names differ from those in ``tools``.
    """
    sft_path = PROJECT_ROOT / "data" / "sft" / "sft_trajectories.json"
    if not sft_path.exists():
        print("[SKIP] SFT data not generated yet")
        return

    # Explicit encoding: do not depend on the platform default for JSON.
    with open(sft_path, encoding="utf-8") as f:
        data = json.load(f)

    has_tools = sum(1 for row in data if "tools" in row)
    total = len(data)

    if has_tools == 0:
        print(
            f"[WARN] SFT data has NO tool definitions ({total} "
            "trajectories). Regenerate with: "
            "python scripts/generate_sft_data.py"
        )
    elif has_tools < total:
        print(f"[WARN] Only {has_tools}/{total} trajectories have tools")
    else:
        # Verify tools match (only the first row is compared)
        first_tools = data[0]["tools"]
        first_names = {t["function"]["name"] for t in first_tools}
        expected_names = {t["function"]["name"] for t in tools}
        assert first_names == expected_names, (
            f"SFT data tools {first_names} != expected {expected_names}"
        )
        print(f"[PASS] All {total} SFT trajectories have matching tools")


def test_sft_tool_call_format(tokenizer, tools: list[dict]):
    """Verify SFT tool_calls render correctly through chat template.

    Renders a fixed three-step conversation (describe, query, answer) and
    checks that each assistant tool call shows up in the output.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        tools: Tool definitions forwarded to the chat template.

    Raises:
        AssertionError: If the number of rendered tool calls is not 3, or
            a function name or the SQL text is missing.
    """
    messages = [
        {"role": "system", "content": "You are a SQL assistant."},
        {"role": "user", "content": "How many rows in employees?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "describe",
                        "arguments": {"table_name": "employees"},
                    },
                }
            ],
        },
        {"role": "tool", "content": "Table 'employees' columns:\n- id"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "query",
                        "arguments": {
                            "sql": "SELECT COUNT(*) FROM employees",
                        },
                    },
                }
            ],
        },
        {"role": "tool", "content": "1. 42"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "answer",
                        "arguments": {"value": "42"},
                    },
                }
            ],
        },
        {"role": "tool", "content": "Answer submitted: correct."},
    ]

    rendered = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
    )

    # Should contain tool_call tags for each assistant turn. Counting an
    # empty string would yield len(rendered) + 1, so match real blocks.
    tool_call_count = len(_RENDERED_TOOL_CALL.findall(rendered))
    assert tool_call_count == 3, f"Expected 3 tool_calls, got {tool_call_count}"

    # Each tool call should have the function name
    assert '"name": "describe"' in rendered
    assert '"name": "query"' in rendered
    assert '"name": "answer"' in rendered

    # SQL should be present (not null)
    assert "SELECT COUNT" in rendered

    print("[PASS] Multi-turn tool_calls render correctly with tools")


def main():
    """Load the tokenizer and run every alignment check in order.

    Any failing check raises ``AssertionError`` and stops the run; the
    final banner is printed only when all checks complete.
    """
    model_name = "Qwen/Qwen3-0.6B"
    print(f"Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("\n--- Tool Definition Tests ---")
    tools = test_tool_definitions_match_class()

    print("\n--- Prompt Parity Tests ---")
    rendered = test_prompt_parity(tokenizer, tools)

    print("\n--- Tool Presence Tests ---")
    test_tools_in_rendered_prompt(rendered, tools)

    print("\n--- SFT Data Tests ---")
    test_sft_data_has_tools(tools)

    print("\n--- Multi-Turn Rendering Tests ---")
    test_sft_tool_call_format(tokenizer, tools)

    print("\n--- Rendered Prompt Preview ---")
    # Show first 600 chars of the rendered prompt
    print(rendered[:600])
    print("...")

    print("\n=== ALL TESTS PASSED ===")


if __name__ == "__main__":
    main()