""" Task definitions and deterministic graders for the SQL Query Optimizer environment. Each task returns a TaskDef with: - id, name, difficulty - query: the broken/unoptimised SQL the agent must fix - schema_context: relevant DDL - description: what the agent must accomplish - grader(rewritten_query) -> GraderResult(score, breakdown, feedback) """ from __future__ import annotations import re import dataclasses from typing import Callable, Dict, Optional _MIN_SCORE_EPS = 0.001 _MAX_SCORE_EPS = 0.999 def _strict_open_score(value: float) -> float: return round(min(max(float(value), _MIN_SCORE_EPS), _MAX_SCORE_EPS), 3) @dataclasses.dataclass class GraderResult: score: float # 0.0 – 1.0 correctness: float = 0.0 performance: float = 0.0 style: float = 0.0 feedback: str = "" @dataclasses.dataclass class TaskDef: id: int name: str difficulty: str # easy | medium | hard description: str query: str schema_context: str hint: Optional[str] max_steps: int grader: Callable[[str], GraderResult] # ────────────────────────────────────────────────────────────────────────────── # Helpers # ────────────────────────────────────────────────────────────────────────────── def _normalise(sql: str) -> str: """Lower-case, collapse whitespace.""" return re.sub(r"\s+", " ", sql.lower().strip()) def _has(sql: str, *patterns: str) -> bool: s = _normalise(sql) return all(p in s for p in patterns) def _missing(sql: str, *patterns: str) -> bool: s = _normalise(sql) return any(p not in s for p in patterns) # ────────────────────────────────────────────────────────────────────────────── # Task 1 — Easy: Fix a broken JOIN (missing ON clause / wrong join type) # ────────────────────────────────────────────────────────────────────────────── _T1_SCHEMA = """ CREATE TABLE orders ( order_id INT PRIMARY KEY, customer_id INT NOT NULL, total DECIMAL(10,2), created_at TIMESTAMP ); CREATE TABLE customers ( customer_id INT PRIMARY KEY, name VARCHAR(255), email VARCHAR(255) ); """ _T1_QUERY = """ SELECT o.order_id, c.name, o.total FROM orders o, customers c WHERE o.total > 100; """ _T1_DESC = ( "The query uses an implicit cross-join (comma syntax) between `orders` and " "`customers` but never links the two tables. Rewrite it with an explicit " "INNER JOIN … ON o.customer_id = c.customer_id, keeping the WHERE filter." ) def _grade_task1(rewritten: str) -> GraderResult: s = _normalise(rewritten) fb: list[str] = [] correctness = 0.0 performance = 0.0 style = 0.0 # Correctness: must have explicit JOIN with the correct ON key if "inner join" in s or ("join" in s and "cross join" not in s): if "on" in s and "customer_id" in s: correctness = 1.0 else: correctness = 0.4 fb.append("JOIN present but ON clause with customer_id is missing.") else: fb.append("Still uses implicit cross-join or missing JOIN keyword.") # Correctness: must still filter total > 100 if "total > 100" in s or "total>100" in s: correctness = min(correctness + 0.0, correctness) # already captured else: correctness = max(correctness - 0.3, 0.0) fb.append("WHERE o.total > 100 filter has been removed.") # Performance: explicit join is better than implicit cross join performance = 1.0 if correctness >= 0.8 else 0.3 # Style: uses table aliases style = 0.5 if re.search(r"\bo\b", s) and re.search(r"\bc\b", s): style = 1.0 elif "select *" not in s: style = 0.7 score = round(correctness * 0.6 + performance * 0.25 + style * 0.15, 3) feedback = " ".join(fb) if fb else "Correct! The JOIN is properly formed." return GraderResult( score=_strict_open_score(score), correctness=correctness, performance=performance, style=style, feedback=feedback, ) # ────────────────────────────────────────────────────────────────────────────── # Task 2 — Medium: Eliminate N+1 correlated subquery # ────────────────────────────────────────────────────────────────────────────── _T2_SCHEMA = """ CREATE TABLE employees ( emp_id INT PRIMARY KEY, name VARCHAR(255), dept_id INT, salary DECIMAL(10,2) ); CREATE TABLE departments ( dept_id INT PRIMARY KEY, dept_name VARCHAR(255), budget DECIMAL(12,2) ); """ _T2_QUERY = """ SELECT e.name, (SELECT d.dept_name FROM departments d WHERE d.dept_id = e.dept_id) AS dept_name FROM employees e WHERE e.salary > 50000; """ _T2_DESC = ( "The query uses a correlated scalar subquery in the SELECT list that fires " "once per row (N+1 problem). Collapse it into a single LEFT JOIN … ON " "e.dept_id = d.dept_id, keeping the salary filter." ) def _grade_task2(rewritten: str) -> GraderResult: s = _normalise(rewritten) fb: list[str] = [] correctness = 0.0 performance = 0.0 style = 0.0 # Correctness: correlated subquery in SELECT must be gone has_correlated = bool( re.search(r"select\s+.*\(\s*select", s) or re.search(r"\(\s*select\b.*\bwhere\b.*=\s*e\.", s) ) if has_correlated: fb.append("Correlated subquery still present in SELECT list.") correctness = 0.1 else: correctness = 0.5 # Correctness: must join on dept_id if "join" in s and "dept_id" in s and "on" in s: correctness = min(correctness + 0.5, 1.0) else: fb.append("Missing JOIN departments ON dept_id.") correctness = max(correctness - 0.1, 0.0) # Correctness: salary filter preserved if "salary" not in s or ("salary > 50000" not in s and "salary>50000" not in s): correctness = max(correctness - 0.2, 0.0) fb.append("salary > 50000 filter is missing or incorrect.") # Performance: single pass vs N+1 performance = 1.0 if not has_correlated and "join" in s else 0.2 # Style: uses aliases, selects explicit columns style = 0.5 if "select *" not in s: style += 0.25 if re.search(r"\be\b|\bd\b", s): style += 0.25 score = round(correctness * 0.55 + performance * 0.30 + style * 0.15, 3) feedback = " ".join(fb) if fb else "Excellent! N+1 eliminated with a clean JOIN." return GraderResult( score=_strict_open_score(score), correctness=correctness, performance=performance, style=style, feedback=feedback, ) # ────────────────────────────────────────────────────────────────────────────── # Task 3 — Hard: Full optimisation (4 independent issues) # ────────────────────────────────────────────────────────────────────────────── _T3_SCHEMA = """ CREATE TABLE products ( product_id INT PRIMARY KEY, name VARCHAR(255), category VARCHAR(100), price DECIMAL(10,2), stock INT ); CREATE TABLE order_items ( item_id INT PRIMARY KEY, order_id INT, product_id INT, quantity INT, unit_price DECIMAL(10,2) ); """ _T3_QUERY = """ SELECT DISTINCT * FROM products p JOIN order_items oi ON p.product_id = oi.product_id WHERE CAST(p.price AS VARCHAR) LIKE '1%' AND p.category = 'Electronics' ORDER BY p.name; """ _T3_DESC = ( "The query has four problems: " "(1) DISTINCT is redundant because product_id is PK and the JOIN is 1-to-many — remove it. " "(2) SELECT * should list only needed columns: p.name, p.category, p.price, oi.quantity, oi.unit_price. " "(3) CAST(p.price AS VARCHAR) LIKE '1%' prevents index use — rewrite as p.price >= 100 AND p.price < 200. " "(4) Add a comment hinting an index on (category, price) would help." ) def _grade_task3(rewritten: str) -> GraderResult: s = _normalise(rewritten) fb: list[str] = [] sub_scores: Dict[str, float] = {} # Sub-criterion 1: DISTINCT removed (0.25) if "distinct" not in s: sub_scores["no_distinct"] = 0.25 else: sub_scores["no_distinct"] = 0.0 fb.append("DISTINCT still present — it's redundant here.") # Sub-criterion 2: SELECT * replaced with explicit columns (0.25) if "select *" not in s and all( col in s for col in ("p.name", "p.price", "oi.quantity") ): sub_scores["explicit_columns"] = 0.25 elif "select *" not in s: sub_scores["explicit_columns"] = 0.15 fb.append("SELECT * removed but explicit column list is incomplete.") else: sub_scores["explicit_columns"] = 0.0 fb.append("SELECT * still used — list explicit columns.") # Sub-criterion 3: CAST…LIKE replaced with range predicate (0.25) cast_gone = "cast(" not in s and "cast (" not in s has_price_range = ( ("price >= 100" in s or "price>=100" in s) and ("price < 200" in s or "price<200" in s) ) if cast_gone and has_price_range: sub_scores["sargable"] = 0.25 elif cast_gone: sub_scores["sargable"] = 0.12 fb.append("CAST removed but price range predicate (>= 100 AND < 200) is missing.") else: sub_scores["sargable"] = 0.0 fb.append("CAST(price AS VARCHAR) LIKE … still present — non-sargable predicate.") # Sub-criterion 4: index hint comment present (0.25) raw = rewritten.lower() if "index" in raw and ("category" in raw or "price" in raw): sub_scores["index_hint"] = 0.25 else: sub_scores["index_hint"] = 0.0 fb.append("Missing comment / hint about adding an index on (category, price).") total = sum(sub_scores.values()) correctness = min(sub_scores["no_distinct"] + sub_scores["explicit_columns"], 0.5) * 2 performance = min(sub_scores["sargable"] + sub_scores["index_hint"], 0.5) * 2 style = 1.0 if "select *" not in s else 0.0 feedback = " ".join(fb) if fb else "Perfect optimisation across all four dimensions!" return GraderResult( score=_strict_open_score(total), correctness=round(correctness, 3), performance=round(performance, 3), style=round(style, 3), feedback=feedback, ) # ────────────────────────────────────────────────────────────────────────────── # Registry # ────────────────────────────────────────────────────────────────────────────── TASKS: Dict[int, TaskDef] = { 1: TaskDef( id=1, name="fix-broken-join", difficulty="easy", description=_T1_DESC, query=_T1_QUERY.strip(), schema_context=_T1_SCHEMA.strip(), hint="Replace the comma-separated FROM list with an explicit INNER JOIN … ON.", max_steps=3, grader=_grade_task1, ), 2: TaskDef( id=2, name="eliminate-n-plus-one", difficulty="medium", description=_T2_DESC, query=_T2_QUERY.strip(), schema_context=_T2_SCHEMA.strip(), hint="Move the subquery out of the SELECT list and into a LEFT JOIN.", max_steps=4, grader=_grade_task2, ), 3: TaskDef( id=3, name="full-optimization", difficulty="hard", description=_T3_DESC, query=_T3_QUERY.strip(), schema_context=_T3_SCHEMA.strip(), hint=None, max_steps=5, grader=_grade_task3, ), } def get_task(task_id: int) -> TaskDef: if task_id not in TASKS: raise ValueError(f"Unknown task_id {task_id}. Valid: {list(TASKS.keys())}") return TASKS[task_id]