Spaces:
Sleeping
Sleeping
| """ | |
merge_and_train.py
==================
1. Merges nl2sql_cleaned_ready_to_train.jsonl + edge_cases.jsonl
2. Shuffles the combined dataset
3. Retrains using the same GRPO setup as train.py

Run:
    python merge_and_train.py

Flags (env vars):
    EDGE_FILE   — path to the edge cases jsonl (default: edge_cases.jsonl)
    BASE_FILE   — path to the existing cleaned jsonl (default: nl2sql_cleaned_ready_to_train.jsonl)
    MERGED_FILE — merged output path (default: nl2sql_merged_final.jsonl)
    SKIP_MERGE  — set "1" to skip the merge step and go straight to training
| """ | |
| import os, sys, json, random | |
| import torch | |
| from datasets import Dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import LoraConfig | |
| from trl import GRPOConfig, GRPOTrainer | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "0,5,1,6" | |
| sys.path.insert(0, "./server") | |
| from environment import NL2SQLEnvironment | |
| from models import NL2SQLAction | |
| from tasks import all_task_names, get_task | |
# ── Config ───────────────────────────────────────────────────────────────────
# Dataset locations and the merge toggle can be overridden via environment
# variables; the second argument of each os.getenv call is the default.
BASE_FILE = os.getenv("BASE_FILE", "nl2sql_cleaned_ready_to_train.jsonl")
EDGE_FILE = os.getenv("EDGE_FILE", "edge_cases.jsonl")
MERGED_FILE = os.getenv("MERGED_FILE", "nl2sql_merged_final.jsonl")
# Only the exact string "1" enables skipping; any other value runs the merge.
SKIP_MERGE = os.getenv("SKIP_MERGE", "0") == "1"

# Checkpoint loaded for fine-tuning and the directory training output goes to.
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
OUTPUT_DIR = "./qwen-7b-coder-nl2sql-grpo-v2"

# System message placed in front of every prompt built from the task examples.
# NOTE(review): the dash in rule 6 was restored from a mis-encoded character;
# confirm it matches the prompt used to create the existing JSONL data.
SYSTEM_PROMPT = """You are a Senior Database Architect and an expert in SQLite.
Your task is to translate natural language questions into highly optimized, correct SQLite SELECT queries.
STRICT RULES:
1. Output EXACTLY ONE valid SQLite query.
2. DO NOT wrap the query in markdown formatting (no ```sql or ```).
3. DO NOT output any explanations, conversational text, or preambles.
4. ONLY use standard SQLite functions.
5. If the question implies ordering, use the correct ORDER BY clause.
6. SELECT only the columns explicitly requested — no extras.
Your output must be executable directly against the database as-is."""
# ── Step 1: Merge ────────────────────────────────────────────────────────────
def _read_nonblank_lines(path):
    """Return the whitespace-stripped, non-empty lines of the UTF-8 file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        stripped = (raw.strip() for raw in f)
        return [line for line in stripped if line]


def merge_datasets(seed: int = 42) -> None:
    """Concatenate BASE_FILE and EDGE_FILE, shuffle, and write MERGED_FILE.

    Does nothing except print a notice when SKIP_MERGE is set. Lines are
    copied verbatim (stripped of surrounding whitespace); blank lines are
    dropped and the JSON is not parsed here.

    Args:
        seed: Seed for the shuffle. A fixed seed makes the merged file
            reproducible across runs and identical in every process when the
            script is started once per device by a distributed launcher.

    Raises:
        OSError: If an input file cannot be read or the output cannot be
            written.
    """
    if SKIP_MERGE:
        print(f"[SKIP_MERGE=1] Using existing {MERGED_FILE}")
        return

    print(f"Loading base: {BASE_FILE}")
    print(f"Loading edges: {EDGE_FILE}")
    base_lines = _read_nonblank_lines(BASE_FILE)
    edge_lines = _read_nonblank_lines(EDGE_FILE)

    combined = base_lines + edge_lines
    # A private Random instance keeps the order deterministic without touching
    # the global random state other code may rely on.
    random.Random(seed).shuffle(combined)

    # Write to a sibling temp file and rename it into place so a crash (or a
    # concurrent reader) never observes a half-written MERGED_FILE that a
    # later SKIP_MERGE=1 run would silently train on.
    tmp_path = f"{MERGED_FILE}.tmp.{os.getpid()}"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in combined)
    os.replace(tmp_path, MERGED_FILE)

    print(
        f"Merged: {len(base_lines)} base + {len(edge_lines)} edge "
        f"= {len(combined)} total → {MERGED_FILE}"
    )
# ── Step 2: Build HF Dataset ─────────────────────────────────────────────────
def build_dataset(seed: int = 42):
    """Build the training Dataset from the merged JSONL plus the task examples.

    Two sources are always combined (neither is a fallback for the other):

    * every non-blank row of MERGED_FILE, using its "prompt" field and the
      synthetic task name "merged_jsonl";
    * every example of every task returned by all_task_names(), with a prompt
      built from SYSTEM_PROMPT, the task's schema context and the question.

    Args:
        seed: Seed for the final shuffle. The order must be the same in every
            process of a distributed run, because the trainer addresses
            samples by index; an unseeded shuffle gives each process its own
            order.

    Returns:
        A datasets.Dataset with the columns "prompt" and "task_name".

    Raises:
        OSError: If MERGED_FILE cannot be opened.
        ValueError: If a row of MERGED_FILE is not valid JSON or has no
            "prompt" field; the message names the file and line number.
    """
    data = []

    # Rows from the merged JSONL. Only "prompt" (a list of chat messages) is
    # used; the reward function scores these rows by executing the SQL.
    with open(MERGED_FILE, "r", encoding="utf-8") as f:
        for line_no, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                prompt = json.loads(line)["prompt"]
            except (json.JSONDecodeError, KeyError, TypeError) as err:
                raise ValueError(
                    f"{MERGED_FILE}:{line_no}: expected a JSON object with a 'prompt' field"
                ) from err
            data.append({"prompt": prompt, "task_name": "merged_jsonl"})

    # Original task examples keep their real task name so the reward function
    # can grade them through the environment.
    for t_name in all_task_names():
        task = get_task(t_name)
        schema = task.schema_context()
        data.extend(
            {
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": f"SCHEMA:\n{schema}\n\nQUESTION: {ex.question}"},
                ],
                "task_name": t_name,
            }
            for ex in task.examples
        )

    random.Random(seed).shuffle(data)
    print(f"Dataset size: {len(data)} samples")
    return Dataset.from_list(data)
# ── Step 3: Reward function ──────────────────────────────────────────────────
def sql_reward_func(prompts, completions, task_name, **kwargs):
    """Score each completion; used as the reward function for GRPOTrainer.

    Args:
        prompts: Prompts aligned index-by-index with *completions*.
        completions: Generated outputs. Each is either a plain string or a
            list of chat messages whose first element holds the text under
            "content".
        task_name: Either one task name for the whole batch or a list with
            one name per completion.
        **kwargs: Extra dataset columns passed by the trainer; ignored.

    Returns:
        A list of float rewards, one per completion. Rows tagged
        "merged_jsonl" are scored by _execution_reward; all others by the
        reward of the observation returned from NL2SQLEnvironment.step, or
        0.0 if stepping raises.
    """
    import re

    # Compiled once per call instead of importing/looking it up per completion.
    fence_re = re.compile(r"```(?:sql)?\n?(.*?)```", flags=re.DOTALL)
    per_sample_tasks = isinstance(task_name, list)
    # Created on first use so batches made up only of merged_jsonl rows (or
    # empty batches) never construct the environment.
    env = None

    rewards = []
    for idx, completion in enumerate(completions):
        generated = completion[0]["content"] if isinstance(completion, list) else completion
        # Strip code fences defensively: keep only what is inside them.
        generated = fence_re.sub(r"\1", generated).strip()

        t = task_name[idx] if per_sample_tasks else task_name

        # merged_jsonl rows have no matching task in the environment, so they
        # are rewarded purely on whether the query executes.
        if t == "merged_jsonl":
            rewards.append(_execution_reward(generated, prompts[idx]))
            continue

        if env is None:
            env = NL2SQLEnvironment()
        env.reset(task_name=t)
        try:
            obs = env.step(NL2SQLAction(query=generated))
            rewards.append(float(obs.reward))
        except Exception:
            # Deliberate best-effort: a completion the environment cannot
            # grade scores zero rather than aborting the training step.
            rewards.append(0.0)
    return rewards
| def _execution_reward(sql: str, prompt) -> float: | |
| """Simple execution check for merged_jsonl samples.""" | |
| import sqlite3, re as _re | |
| # Extract schema from the user message | |
| user_content = "" | |
| for msg in (prompt if isinstance(prompt, list) else []): | |
| if isinstance(msg, dict) and msg.get("role") == "user": | |
| user_content = msg.get("content", "") | |
| break | |
| schema_match = _re.search(r"SCHEMA:\s*(.*?)\nQUESTION:", user_content, _re.DOTALL) | |
| if not schema_match: | |
| return 0.5 # can't verify, neutral reward | |
| schema_sql = schema_match.group(1).strip() | |
| try: | |
| conn = sqlite3.connect(":memory:") | |
| conn.executescript(schema_sql) | |
| rows = conn.execute(sql).fetchall() | |
| conn.close() | |
| return 1.0 if rows else 0.3 # ran cleanly but empty β partial credit | |
| except Exception: | |
| return 0.0 | |
# ── Step 4: Train ────────────────────────────────────────────────────────────
def main():
    """Run the full pipeline: merge, build the dataset, train, and save.

    Steps:
        1. merge_datasets() writes (or reuses) MERGED_FILE.
        2. build_dataset() produces the training Dataset.
        3. MODEL_NAME is loaded in bfloat16 and trained with GRPO using a
           LoRA adapter and sql_reward_func as the reward.
        4. The main process saves the trained model and tokenizer to
           "{OUTPUT_DIR}/final".
    """
    merge_datasets()
    dataset = build_dataset()

    # GRPOTrainer's documentation requires the processing class to pad on the
    # left (prompts are batched for generation), so the tokenizer is created
    # that way rather than with right padding.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    # LoRA on all attention and MLP projection matrices.
    peft_config = LoraConfig(
        r=128,
        lora_alpha=256,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )

    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        learning_rate=1e-5,  # lower LR for fine-grained edge case tuning
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_completion_length=256,
        num_generations=8,
        temperature=0.5,
        bf16=True,
        logging_steps=5,
        num_train_epochs=5,  # fewer epochs - base knowledge already there
        report_to="none",
        # Keep the "task_name" column: the reward function receives it.
        remove_unused_columns=False,
        ddp_find_unused_parameters=False,
    )

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=sql_reward_func,
        args=training_args,
        train_dataset=dataset,
        peft_config=peft_config,
        processing_class=tokenizer,
    )
    trainer.train()

    # Save once, from the main process only.
    if trainer.accelerator.is_main_process:
        trainer.model.save_pretrained(f"{OUTPUT_DIR}/final")
        tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
        print(f"\nSaved to {OUTPUT_DIR}/final")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()