Spaces:
Sleeping
Sleeping
| """Quick smoke test for all 3 tasks.""" | |
| import sys, json | |
| sys.path.insert(0, ".") | |
| from env.environment import SQLOptimizerEnv | |
| from env.models import Action | |
| env = SQLOptimizerEnv() | |
| # ββ Task 1 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("=== Task 1 (Easy): fix-broken-join ===") | |
| obs = env.reset(1) | |
| print(f" task: {obs.task_name}") | |
| action = Action( | |
| rewritten_query=( | |
| "SELECT o.order_id, c.name, o.total " | |
| "FROM orders o INNER JOIN customers c ON o.customer_id = c.customer_id " | |
| "WHERE o.total > 100" | |
| ), | |
| explanation="Replaced comma cross-join with INNER JOIN ON customer_id", | |
| is_done=True, | |
| ) | |
| obs2, reward, done, info = env.step(action) | |
| print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}") | |
| print(f" feedback: {reward.feedback}") | |
| assert obs2.done == True, "done should be True" | |
| assert info["grader_score"] >= 0.8, f"Expected >=0.8, got {info['grader_score']}" | |
| # ββ Task 2 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print() | |
| print("=== Task 2 (Medium): eliminate-n-plus-one ===") | |
| obs = env.reset(2) | |
| print(f" task: {obs.task_name}") | |
| action = Action( | |
| rewritten_query=( | |
| "SELECT e.name, d.dept_name " | |
| "FROM employees e " | |
| "LEFT JOIN departments d ON e.dept_id = d.dept_id " | |
| "WHERE e.salary > 50000" | |
| ), | |
| explanation="Replaced correlated subquery with a single LEFT JOIN", | |
| is_done=True, | |
| ) | |
| obs2, reward, done, info = env.step(action) | |
| print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}") | |
| print(f" feedback: {reward.feedback}") | |
| assert info["grader_score"] >= 0.7, f"Expected >=0.7, got {info['grader_score']}" | |
| # ββ Task 3 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print() | |
| print("=== Task 3 (Hard): full-optimization ===") | |
| obs = env.reset(3) | |
| print(f" task: {obs.task_name}") | |
| action = Action( | |
| rewritten_query=( | |
| "-- Index hint: consider CREATE INDEX ON products(category, price)\n" | |
| "SELECT p.name, p.category, p.price, oi.quantity, oi.unit_price\n" | |
| "FROM products p\n" | |
| "JOIN order_items oi ON p.product_id = oi.product_id\n" | |
| "WHERE p.price >= 100 AND p.price < 200\n" | |
| " AND p.category = 'Electronics'\n" | |
| "ORDER BY p.name" | |
| ), | |
| explanation="Removed DISTINCT and SELECT *, replaced CAST LIKE with range, added index hint", | |
| is_done=True, | |
| ) | |
| obs2, reward, done, info = env.step(action) | |
| print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}") | |
| print(f" feedback: {reward.feedback}") | |
| assert info["grader_score"] >= 0.9, f"Expected >=0.9, got {info['grader_score']}" | |
| # ββ state() βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print() | |
| print("=== state() ===") | |
| print(json.dumps(env.state(), indent=2)) | |
| # ββ invalid action penalty βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print() | |
| print("=== Invalid action test ===") | |
| env.reset(1) | |
| obs2, reward, done, info = env.step(Action(rewritten_query="", explanation="", is_done=False)) | |
| print(f" step_reward={reward.score} is_invalid={info['is_invalid']}") | |
| assert info["is_invalid"] == True, "Empty query should be flagged invalid" | |
| print() | |
| print("ALL TESTS PASSED") | |