| |
| """Generate SFT warmup trajectories from gold SQL answers. |
| |
| Runs the SQLEnvironment programmatically for each training question: |
| 1. describe() each table in tables_involved |
| 2. query() with the gold SQL |
| 3. answer() with the gold answer |
| |
| Captures real environment responses so the model learns what |
| describe output looks like and how to read query results. |
| |
| Usage: |
| uv run python scripts/generate_sft_data.py |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import sys |
| from pathlib import Path |
|
|
| |
# Make the project root importable so `sql_env` resolves when this file is
# run as a standalone script (scripts/ is one level below the root).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from sql_env.models import SQLAction |
| from sql_env.server.sql_environment import SQLEnvironment |
| from sql_env.server.mock_tokenizer import MockTokenizer |
| from sql_env.training.prompts import get_system_prompt |
| from sql_env.training.trl_adapter import get_tool_definitions |
|
|
|
|
| def _format_answer_for_model(gold_answer: str, answer_type: str) -> str: |
| """Format gold answer as human-readable text the model should produce. |
| |
| Converts Python literals (lists, list-of-lists) into the plain-text |
| formats that the verifier can parse: comma-separated for lists, |
| pipe-separated rows for tables. |
| """ |
| import ast |
|
|
| raw = str(gold_answer).strip() |
|
|
| if answer_type == "list" and raw.startswith("["): |
| try: |
| parsed = ast.literal_eval(raw) |
| if isinstance(parsed, list): |
| return ", ".join(str(item) for item in parsed) |
| except (ValueError, SyntaxError): |
| pass |
|
|
| if answer_type == "table" and raw.startswith("["): |
| try: |
| parsed = ast.literal_eval(raw) |
| if ( |
| isinstance(parsed, list) |
| and parsed |
| and isinstance(parsed[0], (list, tuple)) |
| ): |
| lines = [] |
| for row in parsed: |
| lines.append(" | ".join(str(cell) for cell in row)) |
| return "\n".join(lines) |
| except (ValueError, SyntaxError): |
| pass |
|
|
| return raw |
|
|
|
|
def _make_wrong_query(
    gold_sql: str,
    table_columns: dict[str, list[str]],
) -> tuple[str, str]:
    """Mutate SQL to provoke a realistic execution error.

    One applicable mutation is chosen via the global ``random`` state
    (no seeding here — callers control reproducibility):

    - ``wrong_column``: rename one referenced column to a fake one.
    - ``wrong_table``:  rename one referenced table to a fake one.
    - ``missing_join``: excise one ``JOIN ... ON ...`` clause.

    Args:
        gold_sql: Known-good SQL statement to corrupt.
        table_columns: Mapping of table name -> list of its column names.

    Returns:
        ``(mutated_sql, mutation_type)``, or ``(gold_sql, "none")`` when
        no mutation applies or the substitution turned out to be a no-op.
    """
    import random
    import re

    if not table_columns:
        return gold_sql, "none"

    sql = gold_sql
    table_names = list(table_columns.keys())
    # Lowercased set of every known column, used to guarantee that the
    # fabricated replacement name does not collide with a real column.
    all_columns = {
        column.lower() for columns in table_columns.values() for column in columns
    }

    # Whole-word, case-insensitive occurrences of any known table name.
    table_pattern = re.compile(
        r"\b(" + "|".join(re.escape(name) for name in table_names) + r")\b",
        flags=re.IGNORECASE,
    )
    table_matches = list(table_pattern.finditer(sql))

    # Columns that actually appear (whole-word) in the SQL text.
    column_candidates: list[str] = []
    for columns in table_columns.values():
        for column in columns:
            if re.search(rf"\b{re.escape(column)}\b", sql, flags=re.IGNORECASE):
                column_candidates.append(column)

    # Only offer strategies whose precondition holds for this statement.
    strategies: list[str] = []
    if column_candidates:
        strategies.append("wrong_column")
    if table_matches:
        strategies.append("wrong_table")
    if re.search(r"\bJOIN\b", sql, flags=re.IGNORECASE):
        strategies.append("missing_join")

    if not strategies:
        return gold_sql, "none"

    mutation = random.choice(strategies)

    if mutation == "wrong_column":
        candidate = random.choice(column_candidates)
        replacement = f"{candidate}_old"
        # Append suffixes until the fake column name is truly unknown.
        while replacement.lower() in all_columns:
            replacement += "_x"
        wrong_sql = re.sub(
            rf"\b{re.escape(candidate)}\b",
            replacement,
            sql,
            count=1,
            flags=re.IGNORECASE,
        )
        return (
            (wrong_sql, "wrong_column") if wrong_sql != gold_sql else (gold_sql, "none")
        )

    if mutation == "wrong_table":
        chosen = random.choice(table_matches)
        table_name = chosen.group(0)
        replacement = f"{table_name}_v2"
        while replacement.lower() in {name.lower() for name in table_names}:
            replacement += "_x"
        # NOTE(review): count=1 rewrites the FIRST occurrence, which may
        # differ from the randomly chosen match position — harmless for the
        # goal of provoking an error, but worth confirming.
        wrong_sql = re.sub(
            rf"\b{re.escape(table_name)}\b",
            replacement,
            sql,
            count=1,
            flags=re.IGNORECASE,
        )
        return (
            (wrong_sql, "wrong_table") if wrong_sql != gold_sql else (gold_sql, "none")
        )

    # mutation == "missing_join": drop the first "JOIN <table> [alias] ON ..."
    # span, stopping before the next major clause keyword or end of string.
    join_clause = re.search(
        r"\bJOIN\b\s+[^\s]+(?:\s+\w+)?\s+\bON\b\s+"
        r"[^;]+?(?=\bJOIN\b|\bWHERE\b|\bGROUP\b|\bORDER\b|\bLIMIT\b|$)",
        sql,
        flags=re.IGNORECASE,
    )
    if join_clause:
        wrong_sql = sql[: join_clause.start()] + sql[join_clause.end() :]
        return (
            (wrong_sql, "missing_join") if wrong_sql != gold_sql else (gold_sql, "none")
        )

    return gold_sql, "none"
|
|
|
|
| |
# Default system prompt (non-thinking mode). main() reassigns this module
# global when --enable-thinking is passed on the command line.
SYSTEM_PROMPT = get_system_prompt(enable_thinking=False)


# Tool schemas attached to every generated SFT example.
TOOL_DEFINITIONS = get_tool_definitions()
|
|
|
|
def generate_trajectory(
    env: SQLEnvironment,
    question: dict,
) -> dict | None:
    """Generate a full multi-turn SFT example from one question's gold trajectory.

    Returns a single example with ``messages`` containing the complete
    conversation: system, user, then alternating assistant tool_calls and
    tool responses for describe/query/answer. With ``assistant_only_loss``
    enabled in TRL, loss is computed on all assistant turns (not tool
    responses), so the model learns the full describe→query→answer
    workflow in one forward pass.

    Returns ``None`` when the question is unknown to the environment,
    when any describe/query step errors, or when the gold answer does
    not earn a reward of at least 0.5.

    NOTE: arguments must be JSON strings, not dicts. Qwen3's
    apply_chat_template expands dicts to include all parameter names
    from all tools with null values.
    """

    def _tool_call(name: str, arguments: dict) -> dict:
        """Build an assistant message carrying one tool call (JSON-string args)."""
        return {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": json.dumps(arguments),
                    },
                }
            ],
        }

    matching = [
        q for q in env.questions if q.question_text == question["question_text"]
    ]
    if not matching:
        return None

    # Temporarily narrow env.questions so reset() is guaranteed to pick
    # this exact question; always restore the full list afterwards.
    original_questions = list(env.questions)
    try:
        env.questions = matching
        obs = env.reset(seed=None)
    finally:
        env.questions = original_questions

    # Extract table names from the "Available tables:" listing in the
    # reset observation's schema info.
    tables_from_schema = []
    for line in (obs.schema_info or "").split("\n"):
        stripped = line.strip().lstrip("- ").strip()
        if stripped and stripped != "Available tables:":
            tables_from_schema.append(stripped)
    table_hint = (
        f"Tables: {', '.join(tables_from_schema)}. "
        "Use describe, sample, query, and answer tools."
    )

    # Fix: separate the question text from the hint with a space
    # (previously concatenated directly, yielding "...?Tables: ...").
    user_content = f"{question['question_text']} {table_hint}"
    tools = TOOL_DEFINITIONS

    messages: list[dict] = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]

    # 1) describe() every involved table, capturing the real env output.
    for table in question.get("tables_involved", []):
        messages.append(_tool_call("describe", {"table_name": table}))
        obs = env.step(SQLAction(action_type="DESCRIBE", argument=table))
        if obs.error:
            return None
        messages.append({"role": "tool", "content": obs.result})

    # 2) query() with the gold SQL.
    gold_sql = question["gold_sql"]
    messages.append(_tool_call("query", {"sql": gold_sql}))
    obs = env.step(SQLAction(action_type="QUERY", argument=gold_sql))
    if obs.error:
        return None
    messages.append({"role": "tool", "content": obs.result})

    # 3) answer() with the gold answer, formatted as plain text.
    answer_type = question.get("answer_type", "string")
    gold_answer = _format_answer_for_model(str(question["gold_answer"]), answer_type)
    messages.append(_tool_call("answer", {"value": gold_answer}))
    obs = env.step(SQLAction(action_type="ANSWER", argument=gold_answer))
    messages.append({"role": "tool", "content": obs.result or ""})
    messages.append({"role": "assistant", "content": "Task complete."})

    # Keep only trajectories the environment itself scores as correct.
    if obs.reward is None or obs.reward < 0.5:
        return None

    return {"messages": messages, "tools": tools}
|
|
|
|
def generate_error_recovery_trajectory(
    env: SQLEnvironment,
    question: dict,
) -> dict | None:
    """Generate a trajectory that demonstrates SQL error recovery.

    Scripted conversation: describe the involved tables, issue a
    deliberately broken query (mutated from the gold SQL), observe the
    real error, re-describe a table, then run the correct gold SQL and
    answer. Returns ``None`` whenever any step does not go exactly as
    scripted: no usable mutation, the broken query unexpectedly
    succeeds, a describe or the gold query fails, or reward < 0.5.
    """

    def _tool_call(name: str, arguments: dict) -> dict:
        """Build an assistant message carrying one tool call (JSON-string args)."""
        return {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": json.dumps(arguments),
                    },
                }
            ],
        }

    matching = [
        q for q in env.questions if q.question_text == question["question_text"]
    ]
    if not matching:
        return None

    # Temporarily narrow env.questions so reset() picks this question.
    original_questions = list(env.questions)
    try:
        env.questions = matching
        obs = env.reset(seed=None)
    finally:
        env.questions = original_questions

    # Table names from the "Available tables:" schema listing.
    tables_from_schema = []
    for line in (obs.schema_info or "").split("\n"):
        stripped = line.strip().lstrip("- ").strip()
        if stripped and stripped != "Available tables:":
            tables_from_schema.append(stripped)
    table_hint = (
        f"Tables: {', '.join(tables_from_schema)}. "
        "Use describe, sample, query, and answer tools."
    )

    # Fix: separate the question text from the hint with a space
    # (previously concatenated directly, yielding "...?Tables: ...").
    user_content = f"{question['question_text']} {table_hint}"
    messages: list[dict] = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]

    # Describe each involved table and parse its "- col: type" lines so
    # _make_wrong_query knows which columns really exist.
    tables = question.get("tables_involved", [])
    table_columns: dict[str, list[str]] = {}
    for table in tables:
        messages.append(_tool_call("describe", {"table_name": table}))
        obs = env.step(SQLAction(action_type="DESCRIBE", argument=table))
        if obs.error:
            return None
        describe_result = obs.result or ""
        messages.append({"role": "tool", "content": describe_result})

        parsed_columns: list[str] = []
        for row in describe_result.splitlines():
            stripped = row.strip()
            if not stripped.startswith("- ") or ":" not in stripped:
                continue
            name = stripped[2:].split(":", 1)[0].strip()
            if name:
                parsed_columns.append(name)
        if parsed_columns:
            table_columns[table] = parsed_columns

    gold_sql = question["gold_sql"]
    wrong_sql, mutation_type = _make_wrong_query(gold_sql, table_columns)
    if mutation_type == "none":
        return None

    # The broken query MUST fail, otherwise the example teaches nothing.
    messages.append(_tool_call("query", {"sql": wrong_sql}))
    obs = env.step(SQLAction(action_type="QUERY", argument=wrong_sql))
    if not obs.error:
        return None
    messages.append({"role": "tool", "content": obs.error})

    # Recovery step: re-describe a table before retrying.
    recovery_table = tables[0] if tables else next(iter(table_columns), None)
    if recovery_table is None:
        return None

    messages.append(_tool_call("describe", {"table_name": recovery_table}))
    obs = env.step(SQLAction(action_type="DESCRIBE", argument=recovery_table))
    if obs.error:
        return None
    messages.append({"role": "tool", "content": obs.result or ""})

    # Corrected query: the gold SQL.
    messages.append(_tool_call("query", {"sql": gold_sql}))
    obs = env.step(SQLAction(action_type="QUERY", argument=gold_sql))
    if obs.error:
        return None
    messages.append({"role": "tool", "content": obs.result or ""})

    # Final answer with the gold value, formatted as plain text.
    answer_type = question.get("answer_type", "string")
    gold_answer = _format_answer_for_model(str(question["gold_answer"]), answer_type)
    messages.append(_tool_call("answer", {"value": gold_answer}))
    obs = env.step(SQLAction(action_type="ANSWER", argument=gold_answer))
    messages.append({"role": "tool", "content": obs.result or ""})
    messages.append({"role": "assistant", "content": "Task complete."})

    # Keep only trajectories the environment scores as correct.
    if obs.reward is None or obs.reward < 0.5:
        return None

    return {"messages": messages, "tools": TOOL_DEFINITIONS}
|
|
|
|
| def select_diverse_subset( |
| examples: list[dict], |
| questions: list[dict], |
| max_count: int = 100, |
| seed: int = 42, |
| ) -> list[dict]: |
| """Select a diverse subset covering different databases/difficulties. |
| |
| Stratifies by database_name to ensure broad coverage, then caps at |
| max_count. This avoids the peaked-policy problem where SFT on all |
| data collapses reward variance for GRPO (RC-GRPO, ToolRL findings). |
| """ |
| import random |
|
|
| if len(examples) <= max_count: |
| return examples |
|
|
| |
| q_meta = {} |
| for q in questions: |
| q_meta[q["question_text"]] = { |
| "db": q.get("database_name", "unknown"), |
| "difficulty": q.get("difficulty", "easy"), |
| } |
|
|
| |
| by_db: dict[str, list[dict]] = {} |
| for ex in examples: |
| user_msg = next( |
| (m["content"] for m in ex["messages"] if m["role"] == "user"), |
| "", |
| ) |
| db = "unknown" |
| for qt, meta in q_meta.items(): |
| if user_msg.startswith(qt): |
| db = meta["db"] |
| break |
| by_db.setdefault(db, []).append(ex) |
|
|
| |
| rng = random.Random(seed) |
| for db_examples in by_db.values(): |
| rng.shuffle(db_examples) |
|
|
| selected: list[dict] = [] |
| dbs = sorted(by_db.keys()) |
| idx = 0 |
| while len(selected) < max_count: |
| added = False |
| for db in dbs: |
| if idx < len(by_db[db]) and len(selected) < max_count: |
| selected.append(by_db[db][idx]) |
| added = True |
| idx += 1 |
| if not added: |
| break |
|
|
| return selected |
|
|
|
|
def main() -> None:
    """Generate happy-path and error-recovery SFT trajectories and save them.

    Reads questions_train.json, runs the environment for every question,
    stratifies the resulting happy-path examples, appends a small batch
    of error-recovery examples, and writes everything to
    data/sft/sft_trajectories.json.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Generate SFT trajectories")
    parser.add_argument(
        "--enable-thinking",
        action="store_true",
        help="Omit /no_think from system prompt (for thinking-mode training)",
    )
    cli_args = parser.parse_args()

    # The generator functions read the module-level prompt, so rebuild it
    # here according to the CLI flag before generating anything.
    global SYSTEM_PROMPT
    SYSTEM_PROMPT = get_system_prompt(enable_thinking=cli_args.enable_thinking)
    if cli_args.enable_thinking:
        print("Thinking mode: ON (no /no_think prefix)")

    questions_path = PROJECT_ROOT / "data" / "questions" / "questions_train.json"
    db_dir = PROJECT_ROOT / "data" / "databases"
    output_path = PROJECT_ROOT / "data" / "sft" / "sft_trajectories.json"

    # Cap on happy-path examples; see select_diverse_subset for why the
    # subset is kept small and stratified.
    MAX_SFT_QUESTIONS = 100

    if not questions_path.exists():
        print(f"Questions file not found: {questions_path}")
        print("Run scripts/download_spider_databases.py first.")
        sys.exit(1)

    with open(questions_path) as f:
        questions = json.load(f)

    env = SQLEnvironment(
        questions_path=str(questions_path),
        db_dir=str(db_dir),
        tokenizer=MockTokenizer(),
        step_budget=15,
    )

    # Pass 1: happy-path gold trajectories for every training question.
    all_examples: list[dict] = []
    errors = 0
    for i, question in enumerate(questions):
        try:
            example = generate_trajectory(env, question)
            if example is not None:
                all_examples.append(example)
            else:
                errors += 1
        except Exception as e:
            errors += 1
            # Only print the first few failures to keep the log readable.
            if i < 5:
                print(f"Error on question {i} ({question.get('question_id', i)}): {e}")

        if (i + 1) % 50 == 0:
            print(
                f" Processed {i + 1}/{len(questions)}: "
                f"{len(all_examples)} trajectories, "
                f"{errors} failed"
            )

    # Stratified cap on the happy-path examples.
    selected = select_diverse_subset(
        all_examples,
        questions,
        max_count=MAX_SFT_QUESTIONS,
    )

    # Pass 2: a small batch of error-recovery trajectories, drawn from a
    # deterministically shuffled copy of the question list.
    ERROR_RECOVERY_TARGET = 20
    ERROR_RECOVERY_SEED = 123
    happy_path_count = len(selected)

    import random

    recovery_candidates = list(questions)
    random.Random(ERROR_RECOVERY_SEED).shuffle(recovery_candidates)

    error_recovery_examples: list[dict] = []
    recovery_errors = 0
    for i, question in enumerate(recovery_candidates):
        if len(error_recovery_examples) >= ERROR_RECOVERY_TARGET:
            break
        try:
            example = generate_error_recovery_trajectory(env, question)
            if example is not None:
                error_recovery_examples.append(example)
        except Exception as e:
            recovery_errors += 1
            if i < 5:
                print(
                    "Error generating error-recovery trajectory "
                    f"for question {question.get('question_id', i)}: {e}"
                )

    selected.extend(error_recovery_examples)

    # Write the combined dataset.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(selected, f, indent=2)

    def count_tool_calls(tool_name: str) -> int:
        """Count assistant turns whose (single) tool call targets tool_name."""
        return sum(
            1
            for ex in selected
            for m in ex["messages"]
            if m["role"] == "assistant"
            and m.get("tool_calls", [{}])[0].get("function", {}).get("name")
            == tool_name
        )

    # Summary statistics (previously three copy-pasted comprehensions).
    n_describe = count_tool_calls("describe")
    n_query = count_tool_calls("query")
    n_answer = count_tool_calls("answer")
    print(
        f"\nDone: {len(selected)} multi-turn trajectories "
        f"(from {len(all_examples)} valid, {errors} failed)"
    )
    print(
        "Composition: "
        f"{happy_path_count} happy-path + "
        f"{len(error_recovery_examples)} error-recovery "
        f"({recovery_errors} recovery failures)"
    )
    print(f"Assistant turns: {n_describe} describe, {n_query} query, {n_answer} answer")
    print(f"Output: {output_path}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|