sql-query-optimizer / test_env.py
Param20h's picture
Upload folder using huggingface_hub
210535c verified
"""Quick smoke test for all 3 tasks."""
import sys, json
sys.path.insert(0, ".")
from env.environment import SQLOptimizerEnv
from env.models import Action
env = SQLOptimizerEnv()
# ── Task 1 ──────────────────────────────────────────────────────────────────
print("=== Task 1 (Easy): fix-broken-join ===")
obs = env.reset(1)
print(f" task: {obs.task_name}")
action = Action(
rewritten_query=(
"SELECT o.order_id, c.name, o.total "
"FROM orders o INNER JOIN customers c ON o.customer_id = c.customer_id "
"WHERE o.total > 100"
),
explanation="Replaced comma cross-join with INNER JOIN ON customer_id",
is_done=True,
)
obs2, reward, done, info = env.step(action)
print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
print(f" feedback: {reward.feedback}")
assert obs2.done == True, "done should be True"
assert info["grader_score"] >= 0.8, f"Expected >=0.8, got {info['grader_score']}"
# ── Task 2 ──────────────────────────────────────────────────────────────────
print()
print("=== Task 2 (Medium): eliminate-n-plus-one ===")
obs = env.reset(2)
print(f" task: {obs.task_name}")
action = Action(
rewritten_query=(
"SELECT e.name, d.dept_name "
"FROM employees e "
"LEFT JOIN departments d ON e.dept_id = d.dept_id "
"WHERE e.salary > 50000"
),
explanation="Replaced correlated subquery with a single LEFT JOIN",
is_done=True,
)
obs2, reward, done, info = env.step(action)
print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
print(f" feedback: {reward.feedback}")
assert info["grader_score"] >= 0.7, f"Expected >=0.7, got {info['grader_score']}"
# ── Task 3 ──────────────────────────────────────────────────────────────────
print()
print("=== Task 3 (Hard): full-optimization ===")
obs = env.reset(3)
print(f" task: {obs.task_name}")
action = Action(
rewritten_query=(
"-- Index hint: consider CREATE INDEX ON products(category, price)\n"
"SELECT p.name, p.category, p.price, oi.quantity, oi.unit_price\n"
"FROM products p\n"
"JOIN order_items oi ON p.product_id = oi.product_id\n"
"WHERE p.price >= 100 AND p.price < 200\n"
" AND p.category = 'Electronics'\n"
"ORDER BY p.name"
),
explanation="Removed DISTINCT and SELECT *, replaced CAST LIKE with range, added index hint",
is_done=True,
)
obs2, reward, done, info = env.step(action)
print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
print(f" feedback: {reward.feedback}")
assert info["grader_score"] >= 0.9, f"Expected >=0.9, got {info['grader_score']}"
# ── state() ─────────────────────────────────────────────────────────────────
print()
print("=== state() ===")
print(json.dumps(env.state(), indent=2))
# ── invalid action penalty ───────────────────────────────────────────────────
print()
print("=== Invalid action test ===")
env.reset(1)
obs2, reward, done, info = env.step(Action(rewritten_query="", explanation="", is_done=False))
print(f" step_reward={reward.score} is_invalid={info['is_invalid']}")
assert info["is_invalid"] == True, "Empty query should be flagged invalid"
print()
print("ALL TESTS PASSED")