Spaces:
Sleeping
Sleeping
"""
QueryForge Baseline Inference Script
────────────────────────────────────
Runs a Claude model as an agent against all 3 built-in tasks and reports
a reproducible baseline score.

Usage:
    # All tasks, default model (claude-haiku-4-5):
    python baseline.py

    # Specific model:
    python baseline.py --model claude-opus-4-6

    # Single task:
    python baseline.py --task task_easy_syntax

    # More verbose output:
    python baseline.py --verbose

Requirements:
    ANTHROPIC_API_KEY must be set in the environment.
"""
| import argparse | |
| import os | |
| import re | |
| import sys | |
| import anthropic | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from models import SQLAction | |
| from server.queryforge_environment import QueryforgeEnvironment | |
| from tasks import REGISTRY | |
# ── Constants ────────────────────────────────────────────────────────────────

# Model used when --model is not supplied on the command line.
DEFAULT_MODEL = "claude-haiku-4-5"

# System prompt sent with every request. It forces the model to reply with a
# single fenced ```sql block, which _extract_sql() then pulls out.
SYSTEM_PROMPT = """\
You are an expert SQL engineer. You will be given a SQL debugging or \
optimisation challenge. Your job is to submit a corrected or improved SQL query.
Rules:
- Respond with ONLY a single SQL query inside a ```sql ... ``` code block.
- Do not explain your reasoning outside the code block.
- Do not include multiple statements (no semicolons except at the very end).
- If you receive feedback on a previous attempt, use it to improve your query.
"""
| # ββ SQL extraction βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _SQL_BLOCK = re.compile(r"```(?:sql)?\s*(.*?)```", re.DOTALL | re.IGNORECASE) | |
| def _extract_sql(text: str) -> str: | |
| """Pull the first SQL code block out of Claude's response.""" | |
| match = _SQL_BLOCK.search(text) | |
| if match: | |
| return match.group(1).strip() | |
| # Fallback: return the whole response stripped β better than crashing | |
| return text.strip() | |
| # ββ Formatting helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _hr(char="β", width=70): | |
| print(char * width) | |
| def _score_bar(score: float, width: int = 25) -> str: | |
| filled = int(score * width) | |
| bar = "β" * filled + "β" * (width - filled) | |
| return f"[{bar}] {score:.3f}" | |
# ── Per-task agent loop ──────────────────────────────────────────────────────
def run_task(
    task_id: str,
    model: str,
    client: anthropic.Anthropic,
    verbose: bool = False,
) -> dict:
    """
    Run one episode of a single task.

    The model is shown the task description, and every scored attempt is fed
    back into the conversation until the environment reports done (or the
    environment's own step limit ends the episode).

    Args:
        task_id: ID of the task to run.
        model:   Anthropic model ID used for every attempt.
        client:  Configured Anthropic API client.
        verbose: When True, also print the submitted SQL and truncated
                 feedback for each step.

    Returns:
        dict with keys: task_id, task_title, task_level, best_score,
        attempts, done.
    """
    env = QueryforgeEnvironment()
    obs = env.reset(task_id=task_id)
    if obs.done:
        # reset() returned an error (unknown task_id). Return a zero-score
        # record with the same keys as the success path so callers can
        # aggregate results uniformly (previously task_title/task_level
        # were missing from this dict).
        print(f" ERROR: {obs.feedback}")
        return {
            "task_id": task_id,
            "task_title": task_id,
            "task_level": "?",
            "best_score": 0.0,
            "attempts": 0,
            "done": False,
        }

    # NOTE(review): reaches into a private attribute for the step limit —
    # confirm the environment exposes no public accessor. Hoisted out of the
    # loop since the task does not change within an episode.
    max_steps = env._current_task.max_steps
    print(f"\n Task : {obs.task_title} [{obs.task_level}] (max {max_steps} steps)")
    if verbose:
        print(f" ID : {obs.task_id}")

    # ── Build initial conversation ────────────────────────────────────────
    messages = [
        {
            "role": "user",
            "content": (
                f"Here is your SQL challenge:\n\n{obs.task_description}\n\n"
                "Provide your fixed SQL query."
            ),
        }
    ]

    step = 0
    while not obs.done:
        step += 1

        # ── Call Claude (streamed so long generations stay responsive) ────
        with client.messages.stream(
            model=model,
            max_tokens=512,
            system=SYSTEM_PROMPT,
            messages=messages,
        ) as stream:
            response_text = "".join(stream.text_stream)

        sql = _extract_sql(response_text)
        if verbose:
            print(f"\n ── Step {step}")
            short_sql = sql[:120] + ("…" if len(sql) > 120 else "")
            print(f" SQL: {short_sql}")

        # ── Submit to environment ─────────────────────────────────────────
        obs = env.step(SQLAction(sql=sql))
        # Guard once: obs.reward may be None (the original guarded it for
        # the score bar but then formatted the raw value in the feedback
        # message, which would raise TypeError on None).
        reward = obs.reward or 0.0
        score_bar = _score_bar(reward)
        status = "✓ DONE" if obs.done else f"step {step}/{max_steps}"
        print(f" [{status}] Score: {score_bar}")
        if verbose and obs.feedback:
            fb = obs.feedback[:200] + ("…" if len(obs.feedback) > 200 else "")
            print(f" Feedback: {fb}")
        if obs.done:
            break

        # ── Append exchange to conversation for next attempt ──────────────
        messages.append({"role": "assistant", "content": response_text})
        messages.append({
            "role": "user",
            "content": (
                f"Your query scored {reward:.3f}. Here is the feedback:\n\n"
                f"{obs.feedback}\n\n"
                f"Hint: {obs.hint}\n\n"
                "Please try again with an improved SQL query."
            ),
        })

    return {
        "task_id": task_id,
        "task_title": obs.task_title,
        "task_level": obs.task_level,
        "best_score": obs.best_score,
        "attempts": obs.attempt,
        "done": obs.done,
    }
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: parse arguments, run each task, print a results table.

    Exits with status 1 when ANTHROPIC_API_KEY is not set.
    """
    parser = argparse.ArgumentParser(description="QueryForge Baseline Inference")
    parser.add_argument(
        "--model", default=DEFAULT_MODEL,
        help=f"Anthropic model ID to use (default: {DEFAULT_MODEL})"
    )
    parser.add_argument(
        "--task", default=None,
        help="Run a single task by ID instead of all built-in tasks"
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="Print SQL queries and full feedback for each step"
    )
    args = parser.parse_args()

    # ── Validate API key ──────────────────────────────────────────────────
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        print("ERROR: ANTHROPIC_API_KEY is not set.")
        sys.exit(1)
    client = anthropic.Anthropic(api_key=api_key)

    # ── Determine tasks to run ────────────────────────────────────────────
    if args.task:
        task_ids = [args.task]
    else:
        # NOTE(review): hard-coded IDs — consider deriving this list from
        # tasks.REGISTRY (imported above) so new tasks are picked up
        # automatically; verify REGISTRY's shape first.
        task_ids = ["task_easy_syntax", "task_medium_join", "task_hard_cte"]

    # ── Header ────────────────────────────────────────────────────────────
    _hr()
    print(" QueryForge — Baseline Inference")
    print(f" Model : {args.model}")
    print(f" Tasks : {', '.join(task_ids)}")
    _hr()

    # ── Run each task ─────────────────────────────────────────────────────
    results = []
    for task_id in task_ids:
        print(f"\n{'─' * 70}")
        result = run_task(task_id, args.model, client, verbose=args.verbose)
        results.append(result)

    # ── Results table ─────────────────────────────────────────────────────
    print(f"\n{'─' * 70}")
    print(" BASELINE RESULTS")
    print(f" Model: {args.model}")
    print(f"{'─' * 70}")
    print(f" {'Task':<28} {'Level':<8} {'Steps':>5} {'Best Score'}")
    print(f" {'─' * 28} {'─' * 8} {'─' * 5} {'─' * 30}")

    total_score = 0.0
    for r in results:
        # .get() defaults cover the early-error record from run_task, which
        # may lack title/level/attempts.
        title = r.get("task_title", r["task_id"])[:27]
        level = r.get("task_level", "?")
        attempts = r.get("attempts", "?")
        score = r["best_score"]
        total_score += score
        bar = _score_bar(score)
        print(f" {title:<28} {level:<8} {attempts:>5} {bar}")

    # results is never empty today (at least one task runs), but keep the
    # division guarded.
    avg = total_score / len(results) if results else 0.0
    print(f"{'─' * 70}")
    print(f" {'AVERAGE':<28} {'':8} {'':5} {_score_bar(avg)}")
    print(f"{'─' * 70}\n")


if __name__ == "__main__":
    main()