#!/usr/bin/env python3 """Generate SFT warmup trajectories from gold SQL answers. Runs the SQLEnvironment programmatically for each training question: 1. describe() each table in tables_involved 2. query() with the gold SQL 3. answer() with the gold answer Captures real environment responses so the model learns what describe output looks like and how to read query results. Usage: uv run python scripts/generate_sft_data.py """ from __future__ import annotations import json import sys from pathlib import Path # Ensure project root is importable PROJECT_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(PROJECT_ROOT)) from sql_env.models import SQLAction # noqa: E402 from sql_env.server.sql_environment import SQLEnvironment # noqa: E402 from sql_env.server.mock_tokenizer import MockTokenizer # noqa: E402 from sql_env.training.prompts import get_system_prompt # noqa: E402 from sql_env.training.trl_adapter import get_tool_definitions # noqa: E402 def _format_answer_for_model(gold_answer: str, answer_type: str) -> str: """Format gold answer as human-readable text the model should produce. Converts Python literals (lists, list-of-lists) into the plain-text formats that the verifier can parse: comma-separated for lists, pipe-separated rows for tables. """ import ast raw = str(gold_answer).strip() if answer_type == "list" and raw.startswith("["): try: parsed = ast.literal_eval(raw) if isinstance(parsed, list): return ", ".join(str(item) for item in parsed) except (ValueError, SyntaxError): pass if answer_type == "table" and raw.startswith("["): try: parsed = ast.literal_eval(raw) if ( isinstance(parsed, list) and parsed and isinstance(parsed[0], (list, tuple)) ): lines = [] for row in parsed: lines.append(" | ".join(str(cell) for cell in row)) return "\n".join(lines) except (ValueError, SyntaxError): pass return raw def _make_wrong_query( gold_sql: str, table_columns: dict[str, list[str]], ) -> tuple[str, str]: """Mutate SQL to provoke a realistic execution error.""" import random import re if not table_columns: return gold_sql, "none" sql = gold_sql table_names = list(table_columns.keys()) all_columns = { column.lower() for columns in table_columns.values() for column in columns } table_pattern = re.compile( r"\b(" + "|".join(re.escape(name) for name in table_names) + r")\b", flags=re.IGNORECASE, ) table_matches = list(table_pattern.finditer(sql)) column_candidates: list[str] = [] for columns in table_columns.values(): for column in columns: if re.search(rf"\b{re.escape(column)}\b", sql, flags=re.IGNORECASE): column_candidates.append(column) strategies: list[str] = [] if column_candidates: strategies.append("wrong_column") if table_matches: strategies.append("wrong_table") if re.search(r"\bJOIN\b", sql, flags=re.IGNORECASE): strategies.append("missing_join") if not strategies: return gold_sql, "none" mutation = random.choice(strategies) if mutation == "wrong_column": candidate = random.choice(column_candidates) replacement = f"{candidate}_old" while replacement.lower() in all_columns: replacement += "_x" wrong_sql = re.sub( rf"\b{re.escape(candidate)}\b", replacement, sql, count=1, flags=re.IGNORECASE, ) return ( (wrong_sql, "wrong_column") if wrong_sql != gold_sql else (gold_sql, "none") ) if mutation == "wrong_table": chosen = random.choice(table_matches) table_name = chosen.group(0) replacement = f"{table_name}_v2" while replacement.lower() in {name.lower() for name in table_names}: replacement += "_x" wrong_sql = re.sub( rf"\b{re.escape(table_name)}\b", replacement, sql, count=1, flags=re.IGNORECASE, ) return ( (wrong_sql, "wrong_table") if wrong_sql != gold_sql else (gold_sql, "none") ) join_clause = re.search( r"\bJOIN\b\s+[^\s]+(?:\s+\w+)?\s+\bON\b\s+" r"[^;]+?(?=\bJOIN\b|\bWHERE\b|\bGROUP\b|\bORDER\b|\bLIMIT\b|$)", sql, flags=re.IGNORECASE, ) if join_clause: wrong_sql = sql[: join_clause.start()] + sql[join_clause.end() :] return ( (wrong_sql, "missing_join") if wrong_sql != gold_sql else (gold_sql, "none") ) return gold_sql, "none" # Canonical prompt is defined in training/prompts.py — single source of truth. SYSTEM_PROMPT = get_system_prompt(enable_thinking=False) # Extract tool definitions dynamically from SQLEnvTRL — guarantees # SFT sees the same schema that TRL generates for GRPO. TOOL_DEFINITIONS = get_tool_definitions() def generate_trajectory( env: SQLEnvironment, question: dict, ) -> dict | None: """Generate a full multi-turn SFT example from one question's gold trajectory. Returns a single example with ``messages`` containing the complete conversation: system, user, then alternating assistant tool_calls and tool responses for describe/query/answer. With ``assistant_only_loss`` enabled in TRL, loss is computed on all assistant turns (not tool responses), so the model learns the full describe→query→answer workflow in one forward pass. NOTE: arguments must be JSON strings, not dicts. Qwen3's apply_chat_template expands dicts to include all parameter names from all tools with null values. """ matching = [ q for q in env.questions if q.question_text == question["question_text"] ] if not matching: return None original_questions = list(env.questions) try: env.questions = matching obs = env.reset(seed=None) finally: env.questions = original_questions # Build table hint (matches what TRL adapter returns from reset) tables_from_schema = [] for line in (obs.schema_info or "").split("\n"): stripped = line.strip().lstrip("- ").strip() if stripped and stripped != "Available tables:": tables_from_schema.append(stripped) table_hint = ( f"Tables: {', '.join(tables_from_schema)}. " "Use describe, sample, query, and answer tools." ) user_content = f"{question['question_text']}" + table_hint tools = TOOL_DEFINITIONS messages: list[dict] = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_content}, ] # Step 1: describe each table involved tables = question.get("tables_involved", []) for table in tables: assistant_msg = { "role": "assistant", "tool_calls": [ { "type": "function", "function": { "name": "describe", "arguments": json.dumps({"table_name": table}), }, } ], } messages.append(assistant_msg) obs = env.step(SQLAction(action_type="DESCRIBE", argument=table)) if obs.error: return None messages.append({"role": "tool", "content": obs.result}) # Step 2: query with gold SQL gold_sql = question["gold_sql"] query_msg = { "role": "assistant", "tool_calls": [ { "type": "function", "function": { "name": "query", "arguments": json.dumps({"sql": gold_sql}), }, } ], } messages.append(query_msg) obs = env.step(SQLAction(action_type="QUERY", argument=gold_sql)) if obs.error: return None messages.append({"role": "tool", "content": obs.result}) # Step 3: submit answer answer_type = question.get("answer_type", "string") gold_answer = _format_answer_for_model(str(question["gold_answer"]), answer_type) answer_msg = { "role": "assistant", "tool_calls": [ { "type": "function", "function": { "name": "answer", "arguments": json.dumps({"value": gold_answer}), }, } ], } messages.append(answer_msg) obs = env.step(SQLAction(action_type="ANSWER", argument=gold_answer)) messages.append({"role": "tool", "content": obs.result or ""}) messages.append({"role": "assistant", "content": "Task complete."}) # Only keep if trajectory got correct if obs.reward is None or obs.reward < 0.5: return None return {"messages": messages, "tools": tools} def generate_error_recovery_trajectory( env: SQLEnvironment, question: dict, ) -> dict | None: """Generate a trajectory that demonstrates SQL error recovery.""" matching = [ q for q in env.questions if q.question_text == question["question_text"] ] if not matching: return None original_questions = list(env.questions) try: env.questions = matching obs = env.reset(seed=None) finally: env.questions = original_questions tables_from_schema = [] for line in (obs.schema_info or "").split("\n"): stripped = line.strip().lstrip("- ").strip() if stripped and stripped != "Available tables:": tables_from_schema.append(stripped) table_hint = ( f"Tables: {', '.join(tables_from_schema)}. " "Use describe, sample, query, and answer tools." ) user_content = f"{question['question_text']}" + table_hint messages: list[dict] = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_content}, ] tables = question.get("tables_involved", []) table_columns: dict[str, list[str]] = {} for table in tables: describe_msg = { "role": "assistant", "tool_calls": [ { "type": "function", "function": { "name": "describe", "arguments": json.dumps({"table_name": table}), }, } ], } messages.append(describe_msg) obs = env.step(SQLAction(action_type="DESCRIBE", argument=table)) if obs.error: return None describe_result = obs.result or "" messages.append({"role": "tool", "content": describe_result}) parsed_columns: list[str] = [] for row in describe_result.splitlines(): stripped = row.strip() if not stripped.startswith("- ") or ":" not in stripped: continue name = stripped[2:].split(":", 1)[0].strip() if name: parsed_columns.append(name) if parsed_columns: table_columns[table] = parsed_columns gold_sql = question["gold_sql"] wrong_sql, mutation_type = _make_wrong_query(gold_sql, table_columns) if mutation_type == "none": return None wrong_query_msg = { "role": "assistant", "tool_calls": [ { "type": "function", "function": { "name": "query", "arguments": json.dumps({"sql": wrong_sql}), }, } ], } messages.append(wrong_query_msg) obs = env.step(SQLAction(action_type="QUERY", argument=wrong_sql)) if not obs.error: return None messages.append({"role": "tool", "content": obs.error}) recovery_table = tables[0] if tables else next(iter(table_columns), None) if recovery_table is None: return None recovery_describe_msg = { "role": "assistant", "tool_calls": [ { "type": "function", "function": { "name": "describe", "arguments": json.dumps({"table_name": recovery_table}), }, } ], } messages.append(recovery_describe_msg) obs = env.step(SQLAction(action_type="DESCRIBE", argument=recovery_table)) if obs.error: return None messages.append({"role": "tool", "content": obs.result or ""}) corrected_query_msg = { "role": "assistant", "tool_calls": [ { "type": "function", "function": { "name": "query", "arguments": json.dumps({"sql": gold_sql}), }, } ], } messages.append(corrected_query_msg) obs = env.step(SQLAction(action_type="QUERY", argument=gold_sql)) if obs.error: return None messages.append({"role": "tool", "content": obs.result or ""}) answer_type = question.get("answer_type", "string") gold_answer = _format_answer_for_model(str(question["gold_answer"]), answer_type) answer_msg = { "role": "assistant", "tool_calls": [ { "type": "function", "function": { "name": "answer", "arguments": json.dumps({"value": gold_answer}), }, } ], } messages.append(answer_msg) obs = env.step(SQLAction(action_type="ANSWER", argument=gold_answer)) messages.append({"role": "tool", "content": obs.result or ""}) messages.append({"role": "assistant", "content": "Task complete."}) if obs.reward is None or obs.reward < 0.5: return None return {"messages": messages, "tools": TOOL_DEFINITIONS} def select_diverse_subset( examples: list[dict], questions: list[dict], max_count: int = 100, seed: int = 42, ) -> list[dict]: """Select a diverse subset covering different databases/difficulties. Stratifies by database_name to ensure broad coverage, then caps at max_count. This avoids the peaked-policy problem where SFT on all data collapses reward variance for GRPO (RC-GRPO, ToolRL findings). """ import random if len(examples) <= max_count: return examples # Build question_text → metadata lookup q_meta = {} for q in questions: q_meta[q["question_text"]] = { "db": q.get("database_name", "unknown"), "difficulty": q.get("difficulty", "easy"), } # Group examples by database by_db: dict[str, list[dict]] = {} for ex in examples: user_msg = next( (m["content"] for m in ex["messages"] if m["role"] == "user"), "", ) db = "unknown" for qt, meta in q_meta.items(): if user_msg.startswith(qt): db = meta["db"] break by_db.setdefault(db, []).append(ex) # Round-robin sample from each database rng = random.Random(seed) for db_examples in by_db.values(): rng.shuffle(db_examples) selected: list[dict] = [] dbs = sorted(by_db.keys()) idx = 0 while len(selected) < max_count: added = False for db in dbs: if idx < len(by_db[db]) and len(selected) < max_count: selected.append(by_db[db][idx]) added = True idx += 1 if not added: break return selected def main() -> None: import argparse parser = argparse.ArgumentParser(description="Generate SFT trajectories") parser.add_argument( "--enable-thinking", action="store_true", help="Omit /no_think from system prompt (for thinking-mode training)", ) cli_args = parser.parse_args() # Update module-level SYSTEM_PROMPT based on flag global SYSTEM_PROMPT SYSTEM_PROMPT = get_system_prompt(enable_thinking=cli_args.enable_thinking) if cli_args.enable_thinking: print("Thinking mode: ON (no /no_think prefix)") questions_path = PROJECT_ROOT / "data" / "questions" / "questions_train.json" db_dir = PROJECT_ROOT / "data" / "databases" output_path = PROJECT_ROOT / "data" / "sft" / "sft_trajectories.json" # SFT warmup needs just enough to teach tool-calling format. # Research (RC-GRPO, ToolRL) shows training on all data creates # a peaked policy that collapses GRPO reward variance. MAX_SFT_QUESTIONS = 100 if not questions_path.exists(): print(f"Questions file not found: {questions_path}") print("Run scripts/download_spider_databases.py first.") sys.exit(1) with open(questions_path) as f: questions = json.load(f) env = SQLEnvironment( questions_path=str(questions_path), db_dir=str(db_dir), tokenizer=MockTokenizer(), step_budget=15, ) all_examples: list[dict] = [] errors = 0 for i, question in enumerate(questions): try: example = generate_trajectory(env, question) if example is not None: all_examples.append(example) else: errors += 1 except Exception as e: errors += 1 if i < 5: print(f"Error on question {i} ({question.get('question_id', i)}): {e}") if (i + 1) % 50 == 0: print( f" Processed {i + 1}/{len(questions)}: " f"{len(all_examples)} trajectories, " f"{errors} failed" ) selected = select_diverse_subset( all_examples, questions, max_count=MAX_SFT_QUESTIONS, ) # Generate error-recovery trajectories ERROR_RECOVERY_TARGET = 20 ERROR_RECOVERY_SEED = 123 happy_path_count = len(selected) import random recovery_candidates = list(questions) random.Random(ERROR_RECOVERY_SEED).shuffle(recovery_candidates) error_recovery_examples: list[dict] = [] recovery_errors = 0 for i, question in enumerate(recovery_candidates): if len(error_recovery_examples) >= ERROR_RECOVERY_TARGET: break try: example = generate_error_recovery_trajectory(env, question) if example is not None: error_recovery_examples.append(example) except Exception as e: recovery_errors += 1 if i < 5: print( "Error generating error-recovery trajectory " f"for question {question.get('question_id', i)}: {e}" ) selected.extend(error_recovery_examples) # Write output output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w") as f: json.dump(selected, f, indent=2) # Report turn distribution n_describe = sum( 1 for ex in selected for m in ex["messages"] if m.get("tool_calls", [{}])[0].get("function", {}).get("name") == "describe" if m["role"] == "assistant" ) n_query = sum( 1 for ex in selected for m in ex["messages"] if m.get("tool_calls", [{}])[0].get("function", {}).get("name") == "query" if m["role"] == "assistant" ) n_answer = sum( 1 for ex in selected for m in ex["messages"] if m.get("tool_calls", [{}])[0].get("function", {}).get("name") == "answer" if m["role"] == "assistant" ) print( f"\nDone: {len(selected)} multi-turn trajectories " f"(from {len(all_examples)} valid, {errors} failed)" ) print( "Composition: " f"{happy_path_count} happy-path + " f"{len(error_recovery_examples)} error-recovery " f"({recovery_errors} recovery failures)" ) print(f"Assistant turns: {n_describe} describe, {n_query} query, {n_answer} answer") print(f"Output: {output_path}") if __name__ == "__main__": main()