"""Evaluate a LoRA-fine-tuned Qwen model on NL2SQL tasks via a local env server.

Loads the base model plus GRPO LoRA weights, then runs a multi-step
query-refinement loop against each task exposed by the NL2SQL environment,
reporting a mean per-step reward as the task score.
"""
import asyncio
import os
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# --- Configuration ---
BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
LORA_DIR = "./qwen-nl2sql-grpo/checkpoint-50"
SPACE_URL = "http://localhost:8000"  # Local server URL
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
MAX_STEPS = 5  # max refinement attempts per task

print("Loading Base Model and LoRA weights...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base_model, LORA_DIR)

# --- System Prompt & LLM Call ---
SYSTEM_PROMPT = """You are an expert SQL analyst working with a SQLite e-commerce database. Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown."""


def call_local_llm(user_prompt: str) -> str:
    """Generate a SQL query for *user_prompt* with the LoRA-adapted model.

    Decodes only the newly generated tokens, removes any markdown code
    fences the model wrapped the query in, and falls back to a harmless
    "SELECT 1" if nothing usable was produced.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=256, temperature=0.2, do_sample=True
        )
    # Decode only the completion (slice off the prompt tokens).
    # FIX: strip before the fence check — a leading newline previously
    # defeated startswith("```") and let fence lines leak into the SQL.
    response = tokenizer.decode(
        outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    # Strip markdown code fences if model wraps in ```sql ... ```
    if response.startswith("```"):
        lines = response.split("\n")
        response = "\n".join(
            l for l in lines if not l.strip().startswith("```")
        ).strip()
    return response if response else "SELECT 1"


def build_user_prompt(question, schema_context, step, last_query,
                      last_error, last_result, result_columns):
    """Build the user turn for the current refinement step.

    On step 1 this is just the question plus an instruction to write SQL.
    On later steps it replays the previous query and either the server
    error or a 3-row result preview, asking the model to correct/refine.
    Note: *schema_context* is accepted but not currently embedded in the
    prompt (presumably the server-side question already carries schema) —
    kept in the signature for caller compatibility.
    """
    parts = [f"QUESTION: {question}", ""]
    if step > 1:
        parts.append(f"Your previous SQL (step {step - 1}):")
        # Collapse internal whitespace so multi-line SQL replays on one line.
        parts.append(f"  {' '.join(last_query.split())}")
        parts.append("")
        if last_error:
            parts.append(f"ERROR: {last_error}")
        elif last_result:
            preview = str(last_result[:3]).replace("\n", " ")
            parts.append(f"RESULT PREVIEW (first 3 rows): {preview}")
            parts.append(f"COLUMNS: {result_columns}")
        parts.append("")
        parts.append("Please correct or refine your query.")
    else:
        parts.append("Write a SQL query to answer the question.")
    return "\n".join(parts)


async def main():
    """Run the agent loop over every task in TASKS and print a summary."""
    from client import NL2SQLEnv, NL2SQLAction

    all_results = []
    for task_name in TASKS:
        print(f"\n--- Starting Task: {task_name} ---")
        # The env server reads the task selection from this variable.
        os.environ["NL2SQL_DEFAULT_TASK"] = task_name
        try:
            async with NL2SQLEnv(base_url=SPACE_URL) as env:
                result = await env.reset()
                obs = result.observation
                rewards = []
                for step in range(1, MAX_STEPS + 1):
                    if obs.done:
                        break
                    user_prompt = build_user_prompt(
                        obs.question, obs.schema_context, step,
                        obs.last_query, obs.last_error,
                        obs.last_result, obs.result_columns
                    )
                    sql = call_local_llm(user_prompt)
                    print(f"Step {step} Agent Output: {sql}")
                    step_result = await env.step(NL2SQLAction(query=sql))
                    obs = step_result.observation
                    reward = obs.reward or 0.0
                    rewards.append(reward)
                    print(f"Step {step} Reward: {reward}")
                    if obs.done:
                        break
                # Mean per-step reward; guard against a zero-step episode.
                score = sum(rewards) / max(len(rewards), 1)
                success = score >= 0.7
                print(f"Final Score for {task_name}: {score:.3f}")
                all_results.append(
                    {"task": task_name, "score": score, "success": success}
                )
        except Exception as e:
            # Best-effort evaluation: report the failure and move on to
            # the next task rather than aborting the whole run.
            print(f"Error testing task {task_name}: {e}")

    print("\n=== Final Results ===")
    for r in all_results:
        print(f"{r['task']}: Score {r['score']:.3f} | Success: {r['success']}")


if __name__ == "__main__":
    asyncio.run(main())