Spaces:
Sleeping
Sleeping
Initial submission: SQL Agent OpenEnv for Meta+HF hackathon
Browse files- .gitignore +8 -0
- Dockerfile +58 -0
- README.md +13 -0
- backend/api/__init__.py +0 -0
- backend/api/demo.py +495 -0
- backend/api/openenv.py +138 -0
- backend/data/.gitkeep +0 -0
- backend/data/benchmark.db +0 -0
- backend/env/__init__.py +0 -0
- backend/env/database.py +430 -0
- backend/env/sql_env.py +594 -0
- backend/env/tasks.py +345 -0
- backend/gepa/__init__.py +0 -0
- backend/gepa/optimizer.py +347 -0
- backend/main.py +104 -0
- backend/requirements.txt +9 -0
- backend/rl/__init__.py +0 -0
- backend/rl/environment.py +266 -0
- backend/rl/error_classifier.py +98 -0
- backend/rl/experience.py +208 -0
- backend/rl/grader.py +99 -0
- backend/rl/linucb.py +190 -0
- backend/rl/repair_strategies.py +219 -0
- backend/rl/types.py +161 -0
- frontend/index.html +14 -0
- frontend/package-lock.json +0 -0
- frontend/package.json +30 -0
- frontend/postcss.config.js +6 -0
- frontend/src/App.tsx +179 -0
- frontend/src/components/BenchmarkPanel.tsx +384 -0
- frontend/src/components/ChatPanel.tsx +599 -0
- frontend/src/components/ERDiagram.tsx +234 -0
- frontend/src/components/Header.tsx +110 -0
- frontend/src/components/LeftSidebar.tsx +157 -0
- frontend/src/components/PerformanceGraph.tsx +175 -0
- frontend/src/components/PromptEvolution.tsx +148 -0
- frontend/src/components/ResultsTable.tsx +78 -0
- frontend/src/components/RightSidebar.tsx +27 -0
- frontend/src/index.css +187 -0
- frontend/src/lib/api.ts +97 -0
- frontend/src/lib/types.ts +131 -0
- frontend/src/main.tsx +19 -0
- frontend/src/store/useStore.ts +175 -0
- frontend/src/vite-env.d.ts +9 -0
- frontend/tailwind.config.js +20 -0
- frontend/tsconfig.json +24 -0
- frontend/vite.config.ts +28 -0
- inference.py +230 -0
- openenv.yaml +137 -0
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
backend/data/rl_weights.json
|
| 5 |
+
backend/data/rl_experiences.json
|
| 6 |
+
backend/data/gepa_prompt.json
|
| 7 |
+
node_modules/
|
| 8 |
+
frontend/dist/
|
Dockerfile
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SQL Agent OpenEnv β Docker build for Hugging Face Spaces
|
| 2 |
+
#
|
| 3 |
+
# Stage 1: Build React frontend
|
| 4 |
+
# Stage 2: Python FastAPI app serving both the API and static UI
|
| 5 |
+
#
|
| 6 |
+
# HF Spaces expects the app to listen on port 7860.
|
| 7 |
+
|
| 8 |
+
# ββ Stage 1: Frontend build βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 9 |
+
FROM node:20-slim AS frontend-builder
|
| 10 |
+
|
| 11 |
+
WORKDIR /app/frontend
|
| 12 |
+
|
| 13 |
+
# Install deps first (layer cache)
|
| 14 |
+
COPY frontend/package.json frontend/package-lock.json* ./
|
| 15 |
+
RUN npm ci --prefer-offline --no-audit
|
| 16 |
+
|
| 17 |
+
# Build the app
|
| 18 |
+
COPY frontend/ ./
|
| 19 |
+
RUN npm run build
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ββ Stage 2: Python runtime βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
+
FROM python:3.11-slim
|
| 24 |
+
|
| 25 |
+
# System deps
|
| 26 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 27 |
+
gcc \
|
| 28 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 29 |
+
|
| 30 |
+
WORKDIR /app
|
| 31 |
+
|
| 32 |
+
# Python deps
|
| 33 |
+
COPY backend/requirements.txt ./requirements.txt
|
| 34 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 35 |
+
|
| 36 |
+
# Copy backend source
|
| 37 |
+
COPY backend/ ./backend/
|
| 38 |
+
|
| 39 |
+
# Copy built frontend
|
| 40 |
+
COPY --from=frontend-builder /app/frontend/dist ./frontend/dist
|
| 41 |
+
|
| 42 |
+
# Copy repo-root artefacts
|
| 43 |
+
COPY inference.py openenv.yaml README.md ./
|
| 44 |
+
|
| 45 |
+
# Ensure data dir exists (RL weights, GEPA prompts, SQLite DB)
|
| 46 |
+
RUN mkdir -p ./backend/data
|
| 47 |
+
|
| 48 |
+
# ββ HF Spaces config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
EXPOSE 7860
|
| 50 |
+
|
| 51 |
+
ENV PORT=7860 \
|
| 52 |
+
PYTHONUNBUFFERED=1 \
|
| 53 |
+
PYTHONDONTWRITEBYTECODE=1
|
| 54 |
+
|
| 55 |
+
# Run from backend/ so relative imports and data/ paths resolve correctly
|
| 56 |
+
WORKDIR /app/backend
|
| 57 |
+
|
| 58 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
README.md
CHANGED
|
@@ -1,4 +1,17 @@
|
|
| 1 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
title: Sql Agent Openenv
|
| 3 |
emoji: π’
|
| 4 |
colorFrom: yellow
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SQL Agent OpenEnv
|
| 3 |
+
emoji: ποΈ
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- sql
|
| 11 |
+
- reinforcement-learning
|
| 12 |
+
- contextual-bandit
|
| 13 |
+
---
|
| 14 |
+
---
|
| 15 |
title: Sql Agent Openenv
|
| 16 |
emoji: π’
|
| 17 |
colorFrom: yellow
|
backend/api/__init__.py
ADDED
|
File without changes
|
backend/api/demo.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Demo API routes β streaming SSE endpoints matching the original TypeScript API.
|
| 3 |
+
|
| 4 |
+
Routes:
|
| 5 |
+
GET /api/init
|
| 6 |
+
POST /api/execute-query (SSE)
|
| 7 |
+
POST /api/benchmark (SSE)
|
| 8 |
+
GET /api/rl-state
|
| 9 |
+
GET /api/schema-graph
|
| 10 |
+
POST /api/feedback
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
import json
|
| 17 |
+
import time
|
| 18 |
+
from typing import AsyncIterator, Optional
|
| 19 |
+
|
| 20 |
+
from fastapi import APIRouter
|
| 21 |
+
from pydantic import BaseModel
|
| 22 |
+
from sse_starlette.sse import EventSourceResponse
|
| 23 |
+
|
| 24 |
+
from env.database import (
|
| 25 |
+
ensure_seeded,
|
| 26 |
+
get_table_stats,
|
| 27 |
+
get_schema_info,
|
| 28 |
+
get_schema_graph,
|
| 29 |
+
execute_query,
|
| 30 |
+
)
|
| 31 |
+
from env.tasks import TASKS, get_task
|
| 32 |
+
from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, _clean_sql
|
| 33 |
+
from rl.environment import get_bandit_state
|
| 34 |
+
from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
|
| 35 |
+
from rl.error_classifier import classify_error, extract_offending_token
|
| 36 |
+
from rl.grader import GraderInput, compute_reward, compute_episode_reward
|
| 37 |
+
from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES
|
| 38 |
+
from gepa.optimizer import get_gepa, QueryResult
|
| 39 |
+
|
| 40 |
+
router = APIRouter()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# βββ /api/init ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
+
|
| 45 |
+
@router.get("/init")
|
| 46 |
+
async def init_db():
|
| 47 |
+
seeded = ensure_seeded()
|
| 48 |
+
tables = get_table_stats()
|
| 49 |
+
return {"tables": tables, "seeded": seeded}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# βββ /api/execute-query βββββββββββββββββββββββββββββββββββββββββββ
|
| 53 |
+
|
| 54 |
+
class ExecuteQueryRequest(BaseModel):
|
| 55 |
+
question: str
|
| 56 |
+
task_id: str = "simple_queries"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@router.post("/execute-query")
|
| 60 |
+
async def execute_query_stream(req: ExecuteQueryRequest):
|
| 61 |
+
async def event_generator() -> AsyncIterator[dict]:
|
| 62 |
+
env = get_env()
|
| 63 |
+
obs = env.reset(req.task_id)
|
| 64 |
+
|
| 65 |
+
# Pick first question of task matching question text, or default
|
| 66 |
+
task = get_task(req.task_id)
|
| 67 |
+
question_obj = task.questions[0]
|
| 68 |
+
# Override question text
|
| 69 |
+
env._episode.question = req.question # type: ignore[union-attr]
|
| 70 |
+
|
| 71 |
+
max_attempts = env.MAX_ATTEMPTS
|
| 72 |
+
done = False
|
| 73 |
+
all_step_rewards: list[float] = []
|
| 74 |
+
success = False
|
| 75 |
+
|
| 76 |
+
# Initial generate action
|
| 77 |
+
action = Action(repair_action="generate")
|
| 78 |
+
|
| 79 |
+
for attempt in range(1, max_attempts + 1):
|
| 80 |
+
yield {"data": json.dumps({"type": "attempt_start", "attempt": attempt})}
|
| 81 |
+
|
| 82 |
+
ep = env._episode # type: ignore[union-attr]
|
| 83 |
+
ep.attempt_number = attempt
|
| 84 |
+
|
| 85 |
+
# Generate SQL with streaming
|
| 86 |
+
from env.sql_env import _make_client, _MODEL
|
| 87 |
+
from openai import AsyncOpenAI
|
| 88 |
+
|
| 89 |
+
if attempt == 1 or ep.current_sql is None:
|
| 90 |
+
system_prompt = BASE_SYSTEM_PROMPT
|
| 91 |
+
user_msg = (
|
| 92 |
+
f"Schema:\n{obs.schema_info}\n\nQuestion: {req.question}\n\n"
|
| 93 |
+
"Write a SQL query to answer this question."
|
| 94 |
+
)
|
| 95 |
+
else:
|
| 96 |
+
from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message
|
| 97 |
+
from env.sql_env import REPAIR_ACTION_BY_NAME
|
| 98 |
+
|
| 99 |
+
# Bandit selects action
|
| 100 |
+
if ep.current_features is not None:
|
| 101 |
+
repair_enum, scores = env._bandit.select_action(ep.current_features)
|
| 102 |
+
ucb_scores = {
|
| 103 |
+
REPAIR_ACTION_NAMES[RepairAction(i)]: round(scores[i], 4)
|
| 104 |
+
for i in range(len(scores))
|
| 105 |
+
}
|
| 106 |
+
action = Action(repair_action=REPAIR_ACTION_NAMES[repair_enum])
|
| 107 |
+
yield {"data": json.dumps({
|
| 108 |
+
"type": "rl_action",
|
| 109 |
+
"action": action.repair_action,
|
| 110 |
+
"ucb_scores": ucb_scores,
|
| 111 |
+
})}
|
| 112 |
+
else:
|
| 113 |
+
repair_enum = RepairAction.REWRITE_FULL
|
| 114 |
+
action = Action(repair_action="rewrite_full")
|
| 115 |
+
|
| 116 |
+
suffix = get_repair_system_suffix(repair_enum)
|
| 117 |
+
offending = extract_offending_token(ep.error_message or "")
|
| 118 |
+
ctx = RepairContext(
|
| 119 |
+
schema=obs.schema_info,
|
| 120 |
+
question=req.question,
|
| 121 |
+
failing_sql=ep.current_sql or "",
|
| 122 |
+
error_message=ep.error_message or "",
|
| 123 |
+
offending_token=offending,
|
| 124 |
+
)
|
| 125 |
+
system_prompt = BASE_SYSTEM_PROMPT + suffix
|
| 126 |
+
user_msg = build_repair_user_message(repair_enum, ctx)
|
| 127 |
+
|
| 128 |
+
# Stream SQL generation
|
| 129 |
+
client = _make_client()
|
| 130 |
+
chunks: list[str] = []
|
| 131 |
+
try:
|
| 132 |
+
stream = await client.chat.completions.create(
|
| 133 |
+
model=_MODEL,
|
| 134 |
+
messages=[
|
| 135 |
+
{"role": "system", "content": system_prompt},
|
| 136 |
+
{"role": "user", "content": user_msg},
|
| 137 |
+
],
|
| 138 |
+
stream=True,
|
| 139 |
+
temperature=0.1,
|
| 140 |
+
)
|
| 141 |
+
async for chunk in stream:
|
| 142 |
+
delta = chunk.choices[0].delta.content
|
| 143 |
+
if delta:
|
| 144 |
+
chunks.append(delta)
|
| 145 |
+
yield {"data": json.dumps({"type": "sql_chunk", "chunk": delta})}
|
| 146 |
+
except Exception as e:
|
| 147 |
+
yield {"data": json.dumps({"type": "error", "error": str(e), "error_class": "other"})}
|
| 148 |
+
break
|
| 149 |
+
|
| 150 |
+
generated_sql = _clean_sql("".join(chunks))
|
| 151 |
+
yield {"data": json.dumps({"type": "sql_complete", "sql": generated_sql})}
|
| 152 |
+
yield {"data": json.dumps({"type": "executing"})}
|
| 153 |
+
|
| 154 |
+
rows, error = execute_query(generated_sql)
|
| 155 |
+
|
| 156 |
+
from env.tasks import grade_response
|
| 157 |
+
task_score = grade_response(
|
| 158 |
+
req.task_id, question_obj.id, generated_sql, rows, error, attempt
|
| 159 |
+
)
|
| 160 |
+
attempt_success = task_score >= 0.8
|
| 161 |
+
|
| 162 |
+
current_error_class = None
|
| 163 |
+
error_class_name = None
|
| 164 |
+
|
| 165 |
+
if error:
|
| 166 |
+
ec = classify_error(error)
|
| 167 |
+
current_error_class = ec
|
| 168 |
+
error_class_name = ERROR_CLASS_NAMES[ec]
|
| 169 |
+
|
| 170 |
+
error_changed = (
|
| 171 |
+
ep.previous_error_class is not None
|
| 172 |
+
and ep.previous_error_class != current_error_class
|
| 173 |
+
)
|
| 174 |
+
if ep.previous_error_class == current_error_class:
|
| 175 |
+
ep.consecutive_same_error += 1
|
| 176 |
+
else:
|
| 177 |
+
ep.consecutive_same_error = 1
|
| 178 |
+
|
| 179 |
+
rl_state = RLState(
|
| 180 |
+
error_class=current_error_class,
|
| 181 |
+
attempt_number=attempt,
|
| 182 |
+
previous_action=ep.last_action,
|
| 183 |
+
error_changed=error_changed,
|
| 184 |
+
consecutive_same_error=ep.consecutive_same_error,
|
| 185 |
+
)
|
| 186 |
+
ep.current_rl_state = rl_state
|
| 187 |
+
ep.current_features = featurize(rl_state)
|
| 188 |
+
|
| 189 |
+
# Stream diagnosis chunk
|
| 190 |
+
try:
|
| 191 |
+
diag_stream = await client.chat.completions.create(
|
| 192 |
+
model=_MODEL,
|
| 193 |
+
messages=[
|
| 194 |
+
{"role": "system", "content": "You are a SQL debugger. Briefly explain the error in one sentence."},
|
| 195 |
+
{"role": "user", "content": f"Error: {error}\nSQL: {generated_sql}"},
|
| 196 |
+
],
|
| 197 |
+
stream=True,
|
| 198 |
+
temperature=0.3,
|
| 199 |
+
)
|
| 200 |
+
async for chunk in diag_stream:
|
| 201 |
+
delta = chunk.choices[0].delta.content
|
| 202 |
+
if delta:
|
| 203 |
+
yield {"data": json.dumps({"type": "diagnosis_chunk", "chunk": delta})}
|
| 204 |
+
except Exception:
|
| 205 |
+
pass
|
| 206 |
+
|
| 207 |
+
yield {"data": json.dumps({"type": "error", "error": error, "error_class": error_class_name})}
|
| 208 |
+
|
| 209 |
+
# Grader + RL update
|
| 210 |
+
grader_in = GraderInput(
|
| 211 |
+
success=attempt_success,
|
| 212 |
+
attempt_number=attempt,
|
| 213 |
+
current_error_class=current_error_class,
|
| 214 |
+
previous_error_class=ep.previous_error_class,
|
| 215 |
+
)
|
| 216 |
+
grader_out = compute_reward(grader_in)
|
| 217 |
+
all_step_rewards.append(grader_out.reward)
|
| 218 |
+
|
| 219 |
+
if ep.current_rl_state and ep.current_features:
|
| 220 |
+
repair_enum_for_step = REPAIR_ACTION_BY_NAME.get(
|
| 221 |
+
action.repair_action, RepairAction.REWRITE_FULL
|
| 222 |
+
)
|
| 223 |
+
step_obj = EpisodeStep(
|
| 224 |
+
state=ep.current_rl_state,
|
| 225 |
+
featurized=ep.current_features,
|
| 226 |
+
action=repair_enum_for_step,
|
| 227 |
+
reward=grader_out.reward,
|
| 228 |
+
error_message=error or "",
|
| 229 |
+
sql=generated_sql,
|
| 230 |
+
success=attempt_success,
|
| 231 |
+
)
|
| 232 |
+
ep.steps.append(step_obj)
|
| 233 |
+
env._bandit.update(ep.current_features, repair_enum_for_step, grader_out.reward)
|
| 234 |
+
ep.last_action = repair_enum_for_step
|
| 235 |
+
|
| 236 |
+
ep.current_sql = generated_sql
|
| 237 |
+
ep.error_message = error
|
| 238 |
+
ep.error_class = error_class_name
|
| 239 |
+
ep.previous_error_class = current_error_class
|
| 240 |
+
|
| 241 |
+
yield {"data": json.dumps({
|
| 242 |
+
"type": "rl_reward",
|
| 243 |
+
"reward": grader_out.reward,
|
| 244 |
+
"breakdown": {
|
| 245 |
+
"base": grader_out.breakdown.base,
|
| 246 |
+
"attempt_penalty": grader_out.breakdown.attempt_penalty,
|
| 247 |
+
"severity_bonus": grader_out.breakdown.severity_bonus,
|
| 248 |
+
"change_bonus": grader_out.breakdown.change_bonus,
|
| 249 |
+
},
|
| 250 |
+
})}
|
| 251 |
+
|
| 252 |
+
if attempt_success:
|
| 253 |
+
success = True
|
| 254 |
+
yield {"data": json.dumps({
|
| 255 |
+
"type": "success",
|
| 256 |
+
"rows": rows,
|
| 257 |
+
"row_count": len(rows),
|
| 258 |
+
"sql": generated_sql,
|
| 259 |
+
})}
|
| 260 |
+
done = True
|
| 261 |
+
break
|
| 262 |
+
|
| 263 |
+
total_reward = compute_episode_reward(all_step_rewards, success)
|
| 264 |
+
yield {"data": json.dumps({
|
| 265 |
+
"type": "rl_episode_end",
|
| 266 |
+
"total_reward": total_reward,
|
| 267 |
+
"success": success,
|
| 268 |
+
})}
|
| 269 |
+
|
| 270 |
+
# Record GEPA history
|
| 271 |
+
gepa = get_gepa()
|
| 272 |
+
gepa.record_result(QueryResult(
|
| 273 |
+
question=req.question,
|
| 274 |
+
final_sql=env._episode.current_sql or "" if env._episode else "", # type: ignore[union-attr]
|
| 275 |
+
attempts=len(all_step_rewards),
|
| 276 |
+
success=success,
|
| 277 |
+
errors=[s.error_message for s in (env._episode.steps if env._episode else []) if s.error_message],
|
| 278 |
+
timestamp=time.time(),
|
| 279 |
+
))
|
| 280 |
+
|
| 281 |
+
# Finalize episode
|
| 282 |
+
env._finalize_episode(success=success)
|
| 283 |
+
if env._episode:
|
| 284 |
+
env._episode.done = True
|
| 285 |
+
env._episode.success = success
|
| 286 |
+
|
| 287 |
+
# Trigger GEPA if needed
|
| 288 |
+
if gepa.should_optimize():
|
| 289 |
+
try:
|
| 290 |
+
await gepa.run_optimization_cycle()
|
| 291 |
+
except Exception:
|
| 292 |
+
pass
|
| 293 |
+
|
| 294 |
+
return EventSourceResponse(event_generator())
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# βββ /api/benchmark βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 298 |
+
|
| 299 |
+
class BenchmarkRequest(BaseModel):
|
| 300 |
+
task_id: str = "simple_queries"
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
@router.post("/benchmark")
|
| 304 |
+
async def run_benchmark(req: BenchmarkRequest):
|
| 305 |
+
async def event_generator() -> AsyncIterator[dict]:
|
| 306 |
+
task = get_task(req.task_id)
|
| 307 |
+
scores: list[float] = []
|
| 308 |
+
|
| 309 |
+
for question_obj in task.questions:
|
| 310 |
+
yield {"data": json.dumps({
|
| 311 |
+
"type": "query_start",
|
| 312 |
+
"query_id": question_obj.id,
|
| 313 |
+
"question": question_obj.question,
|
| 314 |
+
})}
|
| 315 |
+
|
| 316 |
+
# Run the question through the env
|
| 317 |
+
env = SQLAgentEnv()
|
| 318 |
+
obs = env.reset_with_question(req.task_id, question_obj.id)
|
| 319 |
+
|
| 320 |
+
attempt = 0
|
| 321 |
+
sql = ""
|
| 322 |
+
success = False
|
| 323 |
+
task_score = 0.0
|
| 324 |
+
max_attempts = env.MAX_ATTEMPTS
|
| 325 |
+
ep = env._episode # type: ignore[union-attr]
|
| 326 |
+
|
| 327 |
+
gepa = get_gepa()
|
| 328 |
+
system_prompt = gepa.get_current_prompt()
|
| 329 |
+
from env.sql_env import _make_client, _MODEL
|
| 330 |
+
|
| 331 |
+
for attempt in range(1, max_attempts + 1):
|
| 332 |
+
ep.attempt_number = attempt
|
| 333 |
+
|
| 334 |
+
if attempt == 1 or ep.current_sql is None:
|
| 335 |
+
user_msg = (
|
| 336 |
+
f"Schema:\n{obs.schema_info}\n\n"
|
| 337 |
+
f"Question: {question_obj.question}\n\n"
|
| 338 |
+
"Write a SQL query to answer this question."
|
| 339 |
+
)
|
| 340 |
+
sys_prompt = system_prompt
|
| 341 |
+
else:
|
| 342 |
+
from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message
|
| 343 |
+
if ep.current_features is not None:
|
| 344 |
+
repair_enum, _ = env._bandit.select_action(ep.current_features)
|
| 345 |
+
else:
|
| 346 |
+
repair_enum = RepairAction.REWRITE_FULL
|
| 347 |
+
suffix = get_repair_system_suffix(repair_enum)
|
| 348 |
+
offending = extract_offending_token(ep.error_message or "")
|
| 349 |
+
ctx = RepairContext(
|
| 350 |
+
schema=obs.schema_info,
|
| 351 |
+
question=question_obj.question,
|
| 352 |
+
failing_sql=ep.current_sql or "",
|
| 353 |
+
error_message=ep.error_message or "",
|
| 354 |
+
offending_token=offending,
|
| 355 |
+
)
|
| 356 |
+
sys_prompt = system_prompt + suffix
|
| 357 |
+
user_msg = build_repair_user_message(repair_enum, ctx)
|
| 358 |
+
|
| 359 |
+
client = _make_client()
|
| 360 |
+
try:
|
| 361 |
+
resp = await client.chat.completions.create(
|
| 362 |
+
model=_MODEL,
|
| 363 |
+
messages=[
|
| 364 |
+
{"role": "system", "content": sys_prompt},
|
| 365 |
+
{"role": "user", "content": user_msg},
|
| 366 |
+
],
|
| 367 |
+
temperature=0.1,
|
| 368 |
+
)
|
| 369 |
+
sql = _clean_sql(resp.choices[0].message.content or "")
|
| 370 |
+
except Exception as e:
|
| 371 |
+
break
|
| 372 |
+
|
| 373 |
+
rows, error = execute_query(sql)
|
| 374 |
+
from env.tasks import grade_response
|
| 375 |
+
task_score = grade_response(
|
| 376 |
+
req.task_id, question_obj.id, sql, rows, error, attempt
|
| 377 |
+
)
|
| 378 |
+
success = task_score >= 0.8
|
| 379 |
+
|
| 380 |
+
current_ec = None
|
| 381 |
+
if error:
|
| 382 |
+
ec = classify_error(error)
|
| 383 |
+
current_ec = ec
|
| 384 |
+
error_changed = ep.previous_error_class is not None and ep.previous_error_class != ec
|
| 385 |
+
if ep.previous_error_class == ec:
|
| 386 |
+
ep.consecutive_same_error += 1
|
| 387 |
+
else:
|
| 388 |
+
ep.consecutive_same_error = 1
|
| 389 |
+
rl_state = RLState(
|
| 390 |
+
error_class=ec,
|
| 391 |
+
attempt_number=attempt,
|
| 392 |
+
previous_action=ep.last_action,
|
| 393 |
+
error_changed=error_changed,
|
| 394 |
+
consecutive_same_error=ep.consecutive_same_error,
|
| 395 |
+
)
|
| 396 |
+
ep.current_rl_state = rl_state
|
| 397 |
+
ep.current_features = featurize(rl_state)
|
| 398 |
+
|
| 399 |
+
from rl.grader import GraderInput, compute_reward
|
| 400 |
+
grader_in = GraderInput(
|
| 401 |
+
success=success,
|
| 402 |
+
attempt_number=attempt,
|
| 403 |
+
current_error_class=current_ec,
|
| 404 |
+
previous_error_class=ep.previous_error_class,
|
| 405 |
+
)
|
| 406 |
+
grader_out = compute_reward(grader_in)
|
| 407 |
+
|
| 408 |
+
ep.current_sql = sql
|
| 409 |
+
ep.error_message = error
|
| 410 |
+
ep.error_class = ERROR_CLASS_NAMES[current_ec] if current_ec else None
|
| 411 |
+
ep.previous_error_class = current_ec
|
| 412 |
+
|
| 413 |
+
if success:
|
| 414 |
+
break
|
| 415 |
+
|
| 416 |
+
scores.append(task_score)
|
| 417 |
+
|
| 418 |
+
yield {"data": json.dumps({
|
| 419 |
+
"type": "query_result",
|
| 420 |
+
"query_id": question_obj.id,
|
| 421 |
+
"success": success,
|
| 422 |
+
"score": task_score,
|
| 423 |
+
"sql": sql,
|
| 424 |
+
"attempts": attempt,
|
| 425 |
+
})}
|
| 426 |
+
|
| 427 |
+
overall_score = sum(scores) / len(scores) if scores else 0.0
|
| 428 |
+
yield {"data": json.dumps({
|
| 429 |
+
"type": "done",
|
| 430 |
+
"overall_score": overall_score,
|
| 431 |
+
"task_id": req.task_id,
|
| 432 |
+
})}
|
| 433 |
+
|
| 434 |
+
return EventSourceResponse(event_generator())
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# βββ /api/rl-state ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 438 |
+
|
| 439 |
+
@router.get("/rl-state")
|
| 440 |
+
async def get_rl_state():
|
| 441 |
+
state = get_bandit_state()
|
| 442 |
+
action_names = [REPAIR_ACTION_NAMES[RepairAction(i)] for i in range(8)]
|
| 443 |
+
action_distribution = {
|
| 444 |
+
name: state["action_counts"][i]
|
| 445 |
+
for i, name in enumerate(action_names)
|
| 446 |
+
}
|
| 447 |
+
return {
|
| 448 |
+
"action_counts": state["action_counts"],
|
| 449 |
+
"alpha": state["alpha"],
|
| 450 |
+
"total_updates": state["total_updates"],
|
| 451 |
+
"action_distribution": action_distribution,
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# βββ /api/schema-graph ββββββββββββββββββββββββββββββββββββββββββββ
|
| 456 |
+
|
| 457 |
+
@router.get("/schema-graph")
|
| 458 |
+
async def schema_graph():
|
| 459 |
+
return get_schema_graph()
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
# βββ /api/feedback ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 463 |
+
|
| 464 |
+
class FeedbackRequest(BaseModel):
|
| 465 |
+
question: str
|
| 466 |
+
sql: str
|
| 467 |
+
correct: bool
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
@router.post("/feedback")
|
| 471 |
+
async def submit_feedback(req: FeedbackRequest):
|
| 472 |
+
gepa = get_gepa()
|
| 473 |
+
gepa.record_result(QueryResult(
|
| 474 |
+
question=req.question,
|
| 475 |
+
final_sql=req.sql,
|
| 476 |
+
attempts=1,
|
| 477 |
+
success=req.correct,
|
| 478 |
+
errors=[] if req.correct else ["User marked as incorrect"],
|
| 479 |
+
timestamp=time.time(),
|
| 480 |
+
))
|
| 481 |
+
|
| 482 |
+
result = None
|
| 483 |
+
if not req.correct and gepa.should_optimize():
|
| 484 |
+
try:
|
| 485 |
+
result = await gepa.run_optimization_cycle(
|
| 486 |
+
user_feedback_context=f"User marked query as incorrect.\nQuestion: {req.question}\nSQL: {req.sql}"
|
| 487 |
+
)
|
| 488 |
+
except Exception:
|
| 489 |
+
pass
|
| 490 |
+
|
| 491 |
+
return {
|
| 492 |
+
"received": True,
|
| 493 |
+
"gepa_triggered": result is not None,
|
| 494 |
+
"reflection": result.get("reflection") if result else None,
|
| 495 |
+
}
|
backend/api/openenv.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenEnv spec routes.
|
| 3 |
+
|
| 4 |
+
POST /env/reset β Observation
|
| 5 |
+
POST /env/step β {observation: Observation, reward: RewardInfo}
|
| 6 |
+
GET /env/state β current episode state dict
|
| 7 |
+
GET /env/tasks β list of task metadata
|
| 8 |
+
GET /env/info β env metadata
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from fastapi import APIRouter, HTTPException
|
| 14 |
+
from pydantic import BaseModel
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
from env.sql_env import get_env, Observation, Action, RewardInfo
|
| 18 |
+
from env.tasks import get_all_tasks
|
| 19 |
+
|
| 20 |
+
router = APIRouter()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# βββ Request Models βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
|
| 25 |
+
class ResetRequest(BaseModel):
|
| 26 |
+
task_id: str = "simple_queries"
|
| 27 |
+
question_id: Optional[str] = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class StepRequest(BaseModel):
|
| 31 |
+
repair_action: str = "generate"
|
| 32 |
+
custom_sql: Optional[str] = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# βββ Routes βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
|
| 37 |
+
@router.post("/reset", response_model=Observation)
|
| 38 |
+
async def env_reset(req: ResetRequest):
|
| 39 |
+
"""Reset the environment to start a new episode."""
|
| 40 |
+
env = get_env()
|
| 41 |
+
if req.question_id:
|
| 42 |
+
obs = env.reset_with_question(req.task_id, req.question_id)
|
| 43 |
+
else:
|
| 44 |
+
obs = env.reset(req.task_id)
|
| 45 |
+
return obs
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@router.post("/step")
|
| 49 |
+
async def env_step(req: StepRequest):
|
| 50 |
+
"""Execute one step in the current episode."""
|
| 51 |
+
env = get_env()
|
| 52 |
+
try:
|
| 53 |
+
action = Action(
|
| 54 |
+
repair_action=req.repair_action,
|
| 55 |
+
custom_sql=req.custom_sql,
|
| 56 |
+
)
|
| 57 |
+
obs, reward = await env.step(action)
|
| 58 |
+
return {
|
| 59 |
+
"observation": obs.model_dump(),
|
| 60 |
+
"reward": reward.model_dump(),
|
| 61 |
+
}
|
| 62 |
+
except RuntimeError as e:
|
| 63 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@router.get("/state")
|
| 67 |
+
async def env_state():
|
| 68 |
+
"""Get the current episode state."""
|
| 69 |
+
env = get_env()
|
| 70 |
+
return env.state()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@router.get("/tasks")
|
| 74 |
+
async def list_tasks():
|
| 75 |
+
"""List all available tasks with metadata."""
|
| 76 |
+
tasks = get_all_tasks()
|
| 77 |
+
return [
|
| 78 |
+
{
|
| 79 |
+
"id": t.id,
|
| 80 |
+
"name": t.name,
|
| 81 |
+
"difficulty": t.difficulty,
|
| 82 |
+
"description": t.description,
|
| 83 |
+
"question_count": len(t.questions),
|
| 84 |
+
"questions": [
|
| 85 |
+
{
|
| 86 |
+
"id": q.id,
|
| 87 |
+
"question": q.question,
|
| 88 |
+
"hint_tables": q.hint_tables,
|
| 89 |
+
}
|
| 90 |
+
for q in t.questions
|
| 91 |
+
],
|
| 92 |
+
}
|
| 93 |
+
for t in tasks
|
| 94 |
+
]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@router.get("/info")
|
| 98 |
+
async def env_info():
|
| 99 |
+
"""Return environment metadata (matches openenv.yaml spec)."""
|
| 100 |
+
return {
|
| 101 |
+
"name": "sql-agent-openenv",
|
| 102 |
+
"version": "1.0.0",
|
| 103 |
+
"description": "SQL generation and repair environment with RL-driven repair strategy selection.",
|
| 104 |
+
"action_space": {
|
| 105 |
+
"type": "discrete",
|
| 106 |
+
"actions": [
|
| 107 |
+
"generate",
|
| 108 |
+
"rewrite_full",
|
| 109 |
+
"fix_column",
|
| 110 |
+
"fix_table",
|
| 111 |
+
"add_groupby",
|
| 112 |
+
"rewrite_cte",
|
| 113 |
+
"fix_syntax",
|
| 114 |
+
"change_dialect",
|
| 115 |
+
"relax_filter",
|
| 116 |
+
],
|
| 117 |
+
},
|
| 118 |
+
"observation_space": {
|
| 119 |
+
"type": "dict",
|
| 120 |
+
"fields": [
|
| 121 |
+
"question",
|
| 122 |
+
"schema_info",
|
| 123 |
+
"current_sql",
|
| 124 |
+
"error_message",
|
| 125 |
+
"error_class",
|
| 126 |
+
"attempt_number",
|
| 127 |
+
"max_attempts",
|
| 128 |
+
"task_id",
|
| 129 |
+
"task_difficulty",
|
| 130 |
+
],
|
| 131 |
+
},
|
| 132 |
+
"reward_range": [-1.5, 1.5],
|
| 133 |
+
"max_steps": 5,
|
| 134 |
+
"tasks": ["simple_queries", "join_queries", "complex_queries"],
|
| 135 |
+
"rl_algorithm": "LinUCB (contextual bandit)",
|
| 136 |
+
"feature_dim": 20,
|
| 137 |
+
"num_actions": 8,
|
| 138 |
+
}
|
backend/data/.gitkeep
ADDED
|
File without changes
|
backend/data/benchmark.db
ADDED
|
Binary file (32.8 kB). View file
|
|
|
backend/env/__init__.py
ADDED
|
File without changes
|
backend/env/database.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQLite database setup and schema for the benchmark marketplace.
|
| 3 |
+
|
| 4 |
+
Tables:
|
| 5 |
+
sellers (id, name, email, country, rating)
|
| 6 |
+
users (id, name, email, created_at, country)
|
| 7 |
+
products (id, name, category, price, stock_quantity, seller_id)
|
| 8 |
+
orders (id, user_id, product_id, quantity, total_price, status, created_at)
|
| 9 |
+
reviews (id, user_id, product_id, rating, comment, created_at)
|
| 10 |
+
|
| 11 |
+
~50 rows per table of realistic seed data.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import sqlite3
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
_DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
|
| 22 |
+
DB_PATH = _DATA_DIR / "benchmark.db"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# βββ Schema βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 26 |
+
|
| 27 |
+
_DDL = """
|
| 28 |
+
CREATE TABLE IF NOT EXISTS sellers (
|
| 29 |
+
id INTEGER PRIMARY KEY,
|
| 30 |
+
name TEXT NOT NULL,
|
| 31 |
+
email TEXT NOT NULL UNIQUE,
|
| 32 |
+
country TEXT NOT NULL,
|
| 33 |
+
rating REAL NOT NULL DEFAULT 4.0
|
| 34 |
+
);
|
| 35 |
+
|
| 36 |
+
CREATE TABLE IF NOT EXISTS users (
|
| 37 |
+
id INTEGER PRIMARY KEY,
|
| 38 |
+
name TEXT NOT NULL,
|
| 39 |
+
email TEXT NOT NULL UNIQUE,
|
| 40 |
+
created_at TEXT NOT NULL,
|
| 41 |
+
country TEXT NOT NULL
|
| 42 |
+
);
|
| 43 |
+
|
| 44 |
+
CREATE TABLE IF NOT EXISTS products (
|
| 45 |
+
id INTEGER PRIMARY KEY,
|
| 46 |
+
name TEXT NOT NULL,
|
| 47 |
+
category TEXT NOT NULL,
|
| 48 |
+
price REAL NOT NULL,
|
| 49 |
+
stock_quantity INTEGER NOT NULL DEFAULT 0,
|
| 50 |
+
seller_id INTEGER NOT NULL REFERENCES sellers(id)
|
| 51 |
+
);
|
| 52 |
+
|
| 53 |
+
CREATE TABLE IF NOT EXISTS orders (
|
| 54 |
+
id INTEGER PRIMARY KEY,
|
| 55 |
+
user_id INTEGER NOT NULL REFERENCES users(id),
|
| 56 |
+
product_id INTEGER NOT NULL REFERENCES products(id),
|
| 57 |
+
quantity INTEGER NOT NULL DEFAULT 1,
|
| 58 |
+
total_price REAL NOT NULL,
|
| 59 |
+
status TEXT NOT NULL DEFAULT 'pending',
|
| 60 |
+
created_at TEXT NOT NULL
|
| 61 |
+
);
|
| 62 |
+
|
| 63 |
+
CREATE TABLE IF NOT EXISTS reviews (
|
| 64 |
+
id INTEGER PRIMARY KEY,
|
| 65 |
+
user_id INTEGER NOT NULL REFERENCES users(id),
|
| 66 |
+
product_id INTEGER NOT NULL REFERENCES products(id),
|
| 67 |
+
rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
|
| 68 |
+
comment TEXT,
|
| 69 |
+
created_at TEXT NOT NULL
|
| 70 |
+
);
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
# βββ Seed Data ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 74 |
+
|
| 75 |
+
_SELLERS = [
|
| 76 |
+
(1, "TechGadgets Inc", "contact@techgadgets.com", "USA", 4.8),
|
| 77 |
+
(2, "FashionHub", "info@fashionhub.co.uk", "UK", 4.5),
|
| 78 |
+
(3, "HomeDecor Pro", "sales@homedecopro.de", "Germany", 4.3),
|
| 79 |
+
(4, "SportZone", "hello@sportzone.fr", "France", 4.6),
|
| 80 |
+
(5, "BookWorld", "support@bookworld.ca", "Canada", 4.9),
|
| 81 |
+
(6, "ElectroMart", "contact@electromart.jp", "Japan", 4.7),
|
| 82 |
+
(7, "GreenGrocer", "team@greengrocer.au", "Australia", 4.4),
|
| 83 |
+
(8, "KidsToys Hub", "info@kidstoys.us", "USA", 4.2),
|
| 84 |
+
(9, "PetSupplies Co", "hello@petsupplies.nl", "Netherlands", 4.6),
|
| 85 |
+
(10, "OfficeSupply Plus", "contact@officesupply.sg", "Singapore", 4.1),
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
_USERS = [
|
| 89 |
+
(1, "Alice Johnson", "alice@example.com", "2023-01-15", "USA"),
|
| 90 |
+
(2, "Bob Smith", "bob@example.com", "2023-02-10", "UK"),
|
| 91 |
+
(3, "Carol White", "carol@example.com", "2023-03-05", "Canada"),
|
| 92 |
+
(4, "David Brown", "david@example.com", "2023-03-20", "Germany"),
|
| 93 |
+
(5, "Emma Davis", "emma@example.com", "2023-04-12", "France"),
|
| 94 |
+
(6, "Frank Miller", "frank@example.com", "2023-05-01", "Australia"),
|
| 95 |
+
(7, "Grace Wilson", "grace@example.com", "2023-05-18", "Japan"),
|
| 96 |
+
(8, "Henry Taylor", "henry@example.com", "2023-06-03", "USA"),
|
| 97 |
+
(9, "Isabella Anderson", "isabella@example.com", "2023-06-25", "UK"),
|
| 98 |
+
(10, "Jack Martinez", "jack@example.com", "2023-07-09", "Spain"),
|
| 99 |
+
(11, "Karen Thomas", "karen@example.com", "2023-07-22", "Italy"),
|
| 100 |
+
(12, "Liam Jackson", "liam@example.com", "2023-08-04", "Brazil"),
|
| 101 |
+
(13, "Mia Harris", "mia@example.com", "2023-08-17", "Canada"),
|
| 102 |
+
(14, "Noah Martin", "noah@example.com", "2023-09-01", "USA"),
|
| 103 |
+
(15, "Olivia Garcia", "olivia@example.com", "2023-09-14", "Mexico"),
|
| 104 |
+
(16, "Paul Robinson", "paul@example.com", "2023-10-02", "Australia"),
|
| 105 |
+
(17, "Quinn Lewis", "quinn@example.com", "2023-10-20", "New Zealand"),
|
| 106 |
+
(18, "Rachel Walker", "rachel@example.com", "2023-11-05", "UK"),
|
| 107 |
+
(19, "Sam Hall", "sam@example.com", "2023-11-19", "USA"),
|
| 108 |
+
(20, "Tina Allen", "tina@example.com", "2023-12-01", "Germany"),
|
| 109 |
+
(21, "Umar Young", "umar@example.com", "2024-01-08", "Pakistan"),
|
| 110 |
+
(22, "Vera Hernandez", "vera@example.com", "2024-01-22", "Spain"),
|
| 111 |
+
(23, "Will King", "will@example.com", "2024-02-06", "USA"),
|
| 112 |
+
(24, "Xena Wright", "xena@example.com", "2024-02-20", "Canada"),
|
| 113 |
+
(25, "Yusuf Lopez", "yusuf@example.com", "2024-03-05", "Morocco"),
|
| 114 |
+
(26, "Zoe Hill", "zoe@example.com", "2024-03-19", "UK"),
|
| 115 |
+
(27, "Aaron Scott", "aaron@example.com", "2024-04-02", "USA"),
|
| 116 |
+
(28, "Bella Green", "bella@example.com", "2024-04-16", "Australia"),
|
| 117 |
+
(29, "Carlos Adams", "carlos@example.com", "2024-05-01", "Brazil"),
|
| 118 |
+
(30, "Diana Baker", "diana@example.com", "2024-05-15", "Canada"),
|
| 119 |
+
(31, "Ethan Gonzalez", "ethan@example.com", "2024-05-29", "USA"),
|
| 120 |
+
(32, "Fatima Nelson", "fatima@example.com", "2024-06-12", "Nigeria"),
|
| 121 |
+
(33, "George Carter", "george@example.com", "2024-06-26", "UK"),
|
| 122 |
+
(34, "Hannah Mitchell", "hannah@example.com", "2024-07-10", "Germany"),
|
| 123 |
+
(35, "Ivan Perez", "ivan@example.com", "2024-07-24", "Russia"),
|
| 124 |
+
(36, "Julia Roberts", "juliar@example.com", "2024-08-07", "USA"),
|
| 125 |
+
(37, "Kevin Turner", "kevin@example.com", "2024-08-21", "Canada"),
|
| 126 |
+
(38, "Luna Phillips", "luna@example.com", "2024-09-04", "France"),
|
| 127 |
+
(39, "Mike Campbell", "mike@example.com", "2024-09-18", "USA"),
|
| 128 |
+
(40, "Nancy Parker", "nancy@example.com", "2024-10-02", "Japan"),
|
| 129 |
+
(41, "Oscar Evans", "oscar@example.com", "2024-10-16", "UK"),
|
| 130 |
+
(42, "Penny Edwards", "penny@example.com", "2024-10-30", "Australia"),
|
| 131 |
+
(43, "Roy Collins", "roy@example.com", "2024-11-13", "USA"),
|
| 132 |
+
(44, "Sara Stewart", "sara@example.com", "2024-11-27", "Canada"),
|
| 133 |
+
(45, "Tom Morris", "tom@example.com", "2024-12-11", "UK"),
|
| 134 |
+
(46, "Uma Rogers", "uma@example.com", "2024-12-25", "India"),
|
| 135 |
+
(47, "Victor Reed", "victor@example.com", "2025-01-08", "USA"),
|
| 136 |
+
(48, "Wendy Cook", "wendy@example.com", "2025-01-22", "Germany"),
|
| 137 |
+
(49, "Xavier Morgan", "xavier@example.com", "2025-02-05", "France"),
|
| 138 |
+
(50, "Yasmin Bell", "yasmin@example.com", "2025-02-19", "UK"),
|
| 139 |
+
]
|
| 140 |
+
|
| 141 |
+
_PRODUCTS = [
|
| 142 |
+
(1, "Wireless Headphones Pro", "Electronics", 149.99, 120, 1),
|
| 143 |
+
(2, "Laptop Stand Adjustable", "Electronics", 49.99, 200, 1),
|
| 144 |
+
(3, "USB-C Hub 7-in-1", "Electronics", 39.99, 350, 6),
|
| 145 |
+
(4, "Mechanical Keyboard RGB", "Electronics", 89.99, 85, 6),
|
| 146 |
+
(5, "Webcam 4K Ultra", "Electronics", 129.99, 60, 1),
|
| 147 |
+
(6, "Summer Floral Dress", "Fashion", 59.99, 180, 2),
|
| 148 |
+
(7, "Men Slim Fit Chinos", "Fashion", 44.99, 220, 2),
|
| 149 |
+
(8, "Leather Wallet Bifold", "Fashion", 34.99, 300, 2),
|
| 150 |
+
(9, "Running Shoes Ultralight", "Fashion", 109.99, 95, 4),
|
| 151 |
+
(10, "Yoga Pants High Waist", "Fashion", 54.99, 150, 4),
|
| 152 |
+
(11, "Ceramic Vase Set", "Home & Garden", 79.99, 70, 3),
|
| 153 |
+
(12, "Bamboo Cutting Board", "Home & Garden", 29.99, 400, 3),
|
| 154 |
+
(13, "Scented Candle Collection", "Home & Garden", 24.99, 500, 3),
|
| 155 |
+
(14, "Smart LED Bulb Pack", "Home & Garden", 59.99, 250, 1),
|
| 156 |
+
(15, "Coffee Table Book Stand", "Home & Garden", 49.99, 130, 3),
|
| 157 |
+
(16, "Protein Powder Vanilla", "Sports & Fitness", 54.99, 210, 4),
|
| 158 |
+
(17, "Resistance Band Set", "Sports & Fitness", 24.99, 600, 4),
|
| 159 |
+
(18, "Yoga Mat Non-Slip", "Sports & Fitness", 39.99, 300, 4),
|
| 160 |
+
(19, "Tennis Racket Pro", "Sports & Fitness", 89.99, 45, 4),
|
| 161 |
+
(20, "Water Bottle Insulated", "Sports & Fitness", 29.99, 450, 7),
|
| 162 |
+
(21, "The Python Handbook", "Books", 29.99, 200, 5),
|
| 163 |
+
(22, "Machine Learning Basics", "Books", 34.99, 175, 5),
|
| 164 |
+
(23, "Data Structures Guide", "Books", 27.99, 220, 5),
|
| 165 |
+
(24, "Mystery Novel Collection", "Books", 49.99, 100, 5),
|
| 166 |
+
(25, "Children Story Box Set", "Books", 44.99, 130, 8),
|
| 167 |
+
(26, "Dog Bed Orthopedic", "Pet Supplies", 79.99, 90, 9),
|
| 168 |
+
(27, "Cat Scratching Post", "Pet Supplies", 34.99, 170, 9),
|
| 169 |
+
(28, "Fish Tank Starter Kit", "Pet Supplies", 59.99, 55, 9),
|
| 170 |
+
(29, "Bird Cage Deluxe", "Pet Supplies", 89.99, 35, 9),
|
| 171 |
+
(30, "Pet Grooming Kit", "Pet Supplies", 39.99, 140, 9),
|
| 172 |
+
(31, "LEGO City Set 600pcs", "Toys", 69.99, 80, 8),
|
| 173 |
+
(32, "Remote Control Car", "Toys", 49.99, 120, 8),
|
| 174 |
+
(33, "Board Game Strategy", "Toys", 34.99, 200, 8),
|
| 175 |
+
(34, "Puzzle 1000 Pieces", "Toys", 24.99, 350, 8),
|
| 176 |
+
(35, "Art & Craft Kit Kids", "Toys", 29.99, 280, 8),
|
| 177 |
+
(36, "Office Desk Organizer", "Office", 39.99, 300, 10),
|
| 178 |
+
(37, "Wireless Mouse Ergonomic", "Electronics", 59.99, 200, 6),
|
| 179 |
+
(38, "Notebook Set Premium", "Office", 19.99, 600, 10),
|
| 180 |
+
(39, "Sticky Notes Colorful", "Office", 9.99, 800, 10),
|
| 181 |
+
(40, "Printer Paper Ream", "Office", 14.99, 500, 10),
|
| 182 |
+
(41, "Smart Watch Fitness", "Electronics", 199.99, 75, 1),
|
| 183 |
+
(42, "Blender High Power", "Home & Garden", 89.99, 110, 3),
|
| 184 |
+
(43, "Air Purifier HEPA", "Home & Garden", 149.99, 65, 1),
|
| 185 |
+
(44, "Backpack Waterproof", "Fashion", 79.99, 160, 2),
|
| 186 |
+
(45, "Sunglasses Polarized", "Fashion", 69.99, 200, 2),
|
| 187 |
+
(46, "Dumbbells Set 20kg", "Sports & Fitness", 79.99, 85, 4),
|
| 188 |
+
(47, "Jump Rope Speed", "Sports & Fitness", 19.99, 400, 4),
|
| 189 |
+
(48, "Graphic Novel Bundle", "Books", 59.99, 90, 5),
|
| 190 |
+
(49, "Phone Stand Adjustable", "Electronics", 24.99, 350, 6),
|
| 191 |
+
(50, "Desk Lamp LED", "Office", 44.99, 230, 10),
|
| 192 |
+
]
|
| 193 |
+
|
| 194 |
+
_ORDERS = [
|
| 195 |
+
(1, 1, 1, 1, 149.99, "delivered", "2024-01-10"),
|
| 196 |
+
(2, 2, 6, 2, 119.98, "delivered", "2024-01-15"),
|
| 197 |
+
(3, 3, 21, 1, 29.99, "delivered", "2024-01-20"),
|
| 198 |
+
(4, 4, 11, 1, 79.99, "delivered", "2024-01-25"),
|
| 199 |
+
(5, 5, 16, 2, 109.98, "delivered", "2024-02-01"),
|
| 200 |
+
(6, 6, 31, 1, 69.99, "delivered", "2024-02-05"),
|
| 201 |
+
(7, 7, 3, 2, 79.98, "shipped", "2024-02-10"),
|
| 202 |
+
(8, 8, 41, 1, 199.99, "delivered", "2024-02-14"),
|
| 203 |
+
(9, 9, 26, 1, 79.99, "delivered", "2024-02-18"),
|
| 204 |
+
(10, 10, 17, 3, 74.97, "delivered", "2024-02-22"),
|
| 205 |
+
(11, 11, 22, 1, 34.99, "delivered", "2024-03-01"),
|
| 206 |
+
(12, 12, 7, 1, 44.99, "delivered", "2024-03-05"),
|
| 207 |
+
(13, 13, 18, 2, 79.98, "delivered", "2024-03-10"),
|
| 208 |
+
(14, 14, 37, 1, 59.99, "shipped", "2024-03-14"),
|
| 209 |
+
(15, 15, 44, 1, 79.99, "delivered", "2024-03-18"),
|
| 210 |
+
(16, 16, 2, 1, 49.99, "delivered", "2024-03-22"),
|
| 211 |
+
(17, 17, 50, 1, 44.99, "pending", "2024-03-26"),
|
| 212 |
+
(18, 18, 5, 1, 129.99, "delivered", "2024-04-01"),
|
| 213 |
+
(19, 19, 12, 2, 59.98, "delivered", "2024-04-05"),
|
| 214 |
+
(20, 20, 33, 1, 34.99, "delivered", "2024-04-09"),
|
| 215 |
+
(21, 21, 9, 1, 109.99, "delivered", "2024-04-13"),
|
| 216 |
+
(22, 22, 14, 2, 119.98, "delivered", "2024-04-17"),
|
| 217 |
+
(23, 23, 43, 1, 149.99, "shipped", "2024-04-21"),
|
| 218 |
+
(24, 24, 25, 1, 44.99, "delivered", "2024-04-25"),
|
| 219 |
+
(25, 25, 8, 2, 69.98, "delivered", "2024-04-29"),
|
| 220 |
+
(26, 26, 4, 1, 89.99, "delivered", "2024-05-03"),
|
| 221 |
+
(27, 27, 29, 1, 89.99, "delivered", "2024-05-07"),
|
| 222 |
+
(28, 28, 20, 3, 89.97, "delivered", "2024-05-11"),
|
| 223 |
+
(29, 29, 35, 2, 59.98, "delivered", "2024-05-15"),
|
| 224 |
+
(30, 30, 46, 1, 79.99, "pending", "2024-05-19"),
|
| 225 |
+
(31, 31, 13, 5, 124.95, "delivered", "2024-05-23"),
|
| 226 |
+
(32, 32, 36, 2, 79.98, "delivered", "2024-05-27"),
|
| 227 |
+
(33, 33, 48, 1, 59.99, "delivered", "2024-05-31"),
|
| 228 |
+
(34, 34, 1, 1, 149.99, "delivered", "2024-06-04"),
|
| 229 |
+
(35, 35, 24, 1, 49.99, "delivered", "2024-06-08"),
|
| 230 |
+
(36, 36, 10, 2, 109.98, "shipped", "2024-06-12"),
|
| 231 |
+
(37, 37, 42, 1, 89.99, "delivered", "2024-06-16"),
|
| 232 |
+
(38, 38, 27, 1, 34.99, "delivered", "2024-06-20"),
|
| 233 |
+
(39, 39, 6, 1, 59.99, "delivered", "2024-06-24"),
|
| 234 |
+
(40, 40, 41, 1, 199.99, "delivered", "2024-06-28"),
|
| 235 |
+
(41, 41, 19, 1, 89.99, "cancelled", "2024-07-02"),
|
| 236 |
+
(42, 42, 34, 2, 49.98, "delivered", "2024-07-06"),
|
| 237 |
+
(43, 43, 23, 1, 27.99, "delivered", "2024-07-10"),
|
| 238 |
+
(44, 44, 47, 3, 59.97, "delivered", "2024-07-14"),
|
| 239 |
+
(45, 45, 15, 1, 49.99, "delivered", "2024-07-18"),
|
| 240 |
+
(46, 46, 32, 1, 49.99, "delivered", "2024-07-22"),
|
| 241 |
+
(47, 47, 3, 1, 39.99, "pending", "2024-07-26"),
|
| 242 |
+
(48, 48, 28, 1, 59.99, "delivered", "2024-07-30"),
|
| 243 |
+
(49, 49, 39, 10, 99.90, "delivered", "2024-08-03"),
|
| 244 |
+
(50, 50, 21, 2, 59.98, "delivered", "2024-08-07"),
|
| 245 |
+
]
|
| 246 |
+
|
| 247 |
+
_REVIEWS = [
|
| 248 |
+
(1, 1, 1, 5, "Excellent headphones, crystal clear sound!", "2024-01-15"),
|
| 249 |
+
(2, 2, 6, 4, "Beautiful dress, fits perfectly.", "2024-01-20"),
|
| 250 |
+
(3, 3, 21, 5, "Best Python book for beginners.", "2024-01-25"),
|
| 251 |
+
(4, 4, 11, 4, "Very elegant vase set.", "2024-01-30"),
|
| 252 |
+
(5, 5, 16, 3, "Decent protein powder, average taste.", "2024-02-05"),
|
| 253 |
+
(6, 6, 31, 5, "My kid loves this LEGO set!", "2024-02-10"),
|
| 254 |
+
(7, 7, 3, 5, "Incredibly useful USB hub.", "2024-02-15"),
|
| 255 |
+
(8, 8, 41, 5, "Smart watch exceeded expectations.", "2024-02-20"),
|
| 256 |
+
(9, 9, 26, 4, "Dog loves the orthopedic bed.", "2024-02-25"),
|
| 257 |
+
(10, 10, 17, 5, "Great resistance bands, very durable.", "2024-03-01"),
|
| 258 |
+
(11, 11, 22, 4, "Solid ML intro book.", "2024-03-06"),
|
| 259 |
+
(12, 12, 7, 3, "Chinos are OK, sizing runs small.", "2024-03-11"),
|
| 260 |
+
(13, 13, 18, 5, "Perfect yoga mat, non-slip is great.", "2024-03-16"),
|
| 261 |
+
(14, 14, 37, 4, "Smooth wireless mouse.", "2024-03-21"),
|
| 262 |
+
(15, 15, 44, 5, "Waterproof backpack is amazing.", "2024-03-26"),
|
| 263 |
+
(16, 16, 2, 4, "Laptop stand is sturdy and adjustable.", "2024-03-31"),
|
| 264 |
+
(17, 17, 49, 3, "Decent phone stand but wobbly.", "2024-04-05"),
|
| 265 |
+
(18, 18, 5, 5, "Best webcam I've ever used.", "2024-04-10"),
|
| 266 |
+
(19, 19, 12, 5, "Bamboo cutting board is beautiful.", "2024-04-15"),
|
| 267 |
+
(20, 20, 33, 4, "Fun strategy board game.", "2024-04-20"),
|
| 268 |
+
(21, 21, 9, 5, "Running shoes are so comfortable!", "2024-04-25"),
|
| 269 |
+
(22, 22, 14, 4, "Smart bulbs work well with app.", "2024-04-30"),
|
| 270 |
+
(23, 23, 43, 4, "Air purifier is quiet and effective.", "2024-05-05"),
|
| 271 |
+
(24, 24, 25, 5, "Beautiful story box set for kids.", "2024-05-10"),
|
| 272 |
+
(25, 25, 8, 4, "Leather wallet is high quality.", "2024-05-15"),
|
| 273 |
+
(26, 26, 4, 5, "Mechanical keyboard is a joy to type on.", "2024-05-20"),
|
| 274 |
+
(27, 27, 29, 4, "Bird cage is spacious and well-made.", "2024-05-25"),
|
| 275 |
+
(28, 28, 20, 5, "Water bottle keeps drinks cold all day.", "2024-05-30"),
|
| 276 |
+
(29, 29, 35, 4, "Great art kit for kids.", "2024-06-04"),
|
| 277 |
+
(30, 30, 46, 4, "Solid dumbbells, good grip.", "2024-06-09"),
|
| 278 |
+
(31, 1, 13, 5, "Scented candles smell amazing.", "2024-06-14"),
|
| 279 |
+
(32, 2, 36, 4, "Desk organizer keeps my workspace tidy.", "2024-06-19"),
|
| 280 |
+
(33, 3, 48, 5, "Graphic novel bundle is worth every penny.", "2024-06-24"),
|
| 281 |
+
(34, 4, 1, 4, "Good headphones, comfy for long sessions.", "2024-06-29"),
|
| 282 |
+
(35, 5, 24, 5, "Love these mystery novels!", "2024-07-04"),
|
| 283 |
+
(36, 6, 10, 4, "High waist yoga pants are flattering.", "2024-07-09"),
|
| 284 |
+
(37, 7, 42, 4, "Powerful blender, handles frozen fruit.", "2024-07-14"),
|
| 285 |
+
(38, 8, 27, 5, "Cat scratching post is well built.", "2024-07-19"),
|
| 286 |
+
(39, 9, 6, 4, "Floral dress is as pictured.", "2024-07-24"),
|
| 287 |
+
(40, 10, 41, 5, "Smart watch has excellent battery life.", "2024-07-29"),
|
| 288 |
+
(41, 11, 19, 2, "Tennis racket feels cheap for the price.", "2024-08-03"),
|
| 289 |
+
(42, 12, 34, 5, "Puzzle is a perfect family activity.", "2024-08-08"),
|
| 290 |
+
(43, 13, 23, 5, "Data structures book is very clear.", "2024-08-13"),
|
| 291 |
+
(44, 14, 47, 4, "Jump rope is fast and durable.", "2024-08-18"),
|
| 292 |
+
(45, 15, 15, 3, "Book stand is okay, a bit light.", "2024-08-23"),
|
| 293 |
+
(46, 16, 32, 5, "Remote control car is very fast!", "2024-08-28"),
|
| 294 |
+
(47, 17, 3, 4, "USB hub works great on MacBook.", "2024-09-02"),
|
| 295 |
+
(48, 18, 28, 4, "Fish tank kit is easy to set up.", "2024-09-07"),
|
| 296 |
+
(49, 19, 38, 5, "Premium notebook has great paper.", "2024-09-12"),
|
| 297 |
+
(50, 20, 21, 5, "Python handbook is my go-to reference.", "2024-09-17"),
|
| 298 |
+
]
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# βββ Public API βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 302 |
+
|
| 303 |
+
def get_db_path() -> Path:
|
| 304 |
+
return DB_PATH
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def ensure_seeded() -> bool:
|
| 308 |
+
"""
|
| 309 |
+
Create the database and populate seed data if not already done.
|
| 310 |
+
Returns True if seed was needed (first run), False if already seeded.
|
| 311 |
+
"""
|
| 312 |
+
_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 313 |
+
conn = sqlite3.connect(str(DB_PATH))
|
| 314 |
+
try:
|
| 315 |
+
conn.executescript(_DDL)
|
| 316 |
+
conn.commit()
|
| 317 |
+
|
| 318 |
+
count = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
|
| 319 |
+
if count >= 50:
|
| 320 |
+
return False # Already seeded
|
| 321 |
+
|
| 322 |
+
conn.execute("DELETE FROM reviews")
|
| 323 |
+
conn.execute("DELETE FROM orders")
|
| 324 |
+
conn.execute("DELETE FROM products")
|
| 325 |
+
conn.execute("DELETE FROM users")
|
| 326 |
+
conn.execute("DELETE FROM sellers")
|
| 327 |
+
|
| 328 |
+
conn.executemany(
|
| 329 |
+
"INSERT OR REPLACE INTO sellers VALUES (?,?,?,?,?)", _SELLERS
|
| 330 |
+
)
|
| 331 |
+
conn.executemany(
|
| 332 |
+
"INSERT OR REPLACE INTO users VALUES (?,?,?,?,?)", _USERS
|
| 333 |
+
)
|
| 334 |
+
conn.executemany(
|
| 335 |
+
"INSERT OR REPLACE INTO products VALUES (?,?,?,?,?,?)", _PRODUCTS
|
| 336 |
+
)
|
| 337 |
+
conn.executemany(
|
| 338 |
+
"INSERT OR REPLACE INTO orders VALUES (?,?,?,?,?,?,?)", _ORDERS
|
| 339 |
+
)
|
| 340 |
+
conn.executemany(
|
| 341 |
+
"INSERT OR REPLACE INTO reviews VALUES (?,?,?,?,?,?)", _REVIEWS
|
| 342 |
+
)
|
| 343 |
+
conn.commit()
|
| 344 |
+
return True
|
| 345 |
+
finally:
|
| 346 |
+
conn.close()
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def get_schema_info() -> str:
|
| 350 |
+
"""
|
| 351 |
+
Return a concise textual schema summary for use in prompts.
|
| 352 |
+
"""
|
| 353 |
+
conn = sqlite3.connect(str(DB_PATH))
|
| 354 |
+
try:
|
| 355 |
+
lines = []
|
| 356 |
+
for table in ["sellers", "users", "products", "orders", "reviews"]:
|
| 357 |
+
info = conn.execute(f"PRAGMA table_info({table})").fetchall()
|
| 358 |
+
cols = ", ".join(
|
| 359 |
+
f"{col[1]} {col[2]}{'(PK)' if col[5] else ''}"
|
| 360 |
+
for col in info
|
| 361 |
+
)
|
| 362 |
+
row_count = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
|
| 363 |
+
lines.append(f"Table: {table} ({row_count} rows)\n Columns: {cols}")
|
| 364 |
+
return "\n\n".join(lines)
|
| 365 |
+
finally:
|
| 366 |
+
conn.close()
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def execute_query(sql: str) -> tuple[list[dict], str | None]:
|
| 370 |
+
"""
|
| 371 |
+
Execute a SQL query and return (rows, error_message).
|
| 372 |
+
rows is a list of dicts; error_message is None on success.
|
| 373 |
+
"""
|
| 374 |
+
conn = sqlite3.connect(str(DB_PATH))
|
| 375 |
+
conn.row_factory = sqlite3.Row
|
| 376 |
+
try:
|
| 377 |
+
cursor = conn.execute(sql)
|
| 378 |
+
rows = [dict(row) for row in cursor.fetchall()]
|
| 379 |
+
return rows, None
|
| 380 |
+
except sqlite3.Error as e:
|
| 381 |
+
return [], str(e)
|
| 382 |
+
finally:
|
| 383 |
+
conn.close()
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def get_table_stats() -> list[dict]:
|
| 387 |
+
"""Return [{name, rows}, ...] for all tables."""
|
| 388 |
+
conn = sqlite3.connect(str(DB_PATH))
|
| 389 |
+
try:
|
| 390 |
+
tables = ["sellers", "users", "products", "orders", "reviews"]
|
| 391 |
+
return [
|
| 392 |
+
{
|
| 393 |
+
"name": t,
|
| 394 |
+
"rows": conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone()[0],
|
| 395 |
+
}
|
| 396 |
+
for t in tables
|
| 397 |
+
]
|
| 398 |
+
finally:
|
| 399 |
+
conn.close()
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def get_schema_graph() -> dict:
|
| 403 |
+
"""Return schema graph with tables, columns, and foreign keys."""
|
| 404 |
+
conn = sqlite3.connect(str(DB_PATH))
|
| 405 |
+
try:
|
| 406 |
+
tables = []
|
| 407 |
+
for table in ["sellers", "users", "products", "orders", "reviews"]:
|
| 408 |
+
info = conn.execute(f"PRAGMA table_info({table})").fetchall()
|
| 409 |
+
columns = [
|
| 410 |
+
{"name": col[1], "type": col[2], "pk": bool(col[5])}
|
| 411 |
+
for col in info
|
| 412 |
+
]
|
| 413 |
+
tables.append({"name": table, "columns": columns})
|
| 414 |
+
|
| 415 |
+
foreign_keys = []
|
| 416 |
+
for table in ["sellers", "users", "products", "orders", "reviews"]:
|
| 417 |
+
fks = conn.execute(f"PRAGMA foreign_key_list({table})").fetchall()
|
| 418 |
+
for fk in fks:
|
| 419 |
+
foreign_keys.append(
|
| 420 |
+
{
|
| 421 |
+
"from_table": table,
|
| 422 |
+
"from_col": fk[3],
|
| 423 |
+
"to_table": fk[2],
|
| 424 |
+
"to_col": fk[4],
|
| 425 |
+
}
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
return {"tables": tables, "foreign_keys": foreign_keys}
|
| 429 |
+
finally:
|
| 430 |
+
conn.close()
|
backend/env/sql_env.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQLAgentEnv β OpenEnv-compliant environment for SQL generation.
|
| 3 |
+
|
| 4 |
+
Observation β Action β (Observation, Reward) loop.
|
| 5 |
+
|
| 6 |
+
The step() function:
|
| 7 |
+
1. Selects a repair prompt based on action.repair_action
|
| 8 |
+
2. Calls the LLM (OpenAI-compatible) to generate/repair SQL
|
| 9 |
+
3. Executes SQL on the benchmark DB
|
| 10 |
+
4. Classifies any error
|
| 11 |
+
5. Computes reward via grader
|
| 12 |
+
6. Updates LinUCB bandit
|
| 13 |
+
7. Returns (new_observation, reward)
|
| 14 |
+
|
| 15 |
+
Environment variables:
|
| 16 |
+
API_BASE_URL β OpenAI-compatible base URL (default: https://api.openai.com/v1)
|
| 17 |
+
MODEL_NAME β model to use (default: gpt-4o-mini)
|
| 18 |
+
HF_TOKEN β bearer token / API key
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import asyncio
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
from typing import Optional, AsyncIterator
|
| 27 |
+
|
| 28 |
+
from openai import AsyncOpenAI
|
| 29 |
+
from pydantic import BaseModel
|
| 30 |
+
|
| 31 |
+
from env.database import ensure_seeded, get_schema_info, execute_query
|
| 32 |
+
from env.tasks import get_task, get_all_tasks, TASKS
|
| 33 |
+
from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
|
| 34 |
+
from rl.error_classifier import classify_error, extract_offending_token
|
| 35 |
+
from rl.grader import GraderInput, compute_reward, compute_episode_reward
|
| 36 |
+
from rl.linucb import LinUCB
|
| 37 |
+
from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message
|
| 38 |
+
from rl.experience import record_episode
|
| 39 |
+
from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES
|
| 40 |
+
|
| 41 |
+
# βββ OpenEnv Models ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Observation(BaseModel):
|
| 45 |
+
question: str
|
| 46 |
+
schema_info: str
|
| 47 |
+
current_sql: Optional[str] = None
|
| 48 |
+
error_message: Optional[str] = None
|
| 49 |
+
error_class: Optional[str] = None
|
| 50 |
+
attempt_number: int = 0
|
| 51 |
+
max_attempts: int = 5
|
| 52 |
+
task_id: str
|
| 53 |
+
task_difficulty: str
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Action(BaseModel):
|
| 57 |
+
repair_action: str # one of 8 repair action names or "generate"
|
| 58 |
+
custom_sql: Optional[str] = None # optional direct SQL override
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class RewardInfo(BaseModel):
|
| 62 |
+
value: float
|
| 63 |
+
success: bool
|
| 64 |
+
done: bool
|
| 65 |
+
info: dict
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# βββ LLM Client ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 69 |
+
|
| 70 |
+
def _make_client() -> AsyncOpenAI:
|
| 71 |
+
return AsyncOpenAI(
|
| 72 |
+
api_key=os.environ.get("HF_TOKEN", ""),
|
| 73 |
+
base_url=os.environ.get("API_BASE_URL", "https://api.openai.com/v1"),
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
_MODEL = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 78 |
+
|
| 79 |
+
BASE_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
|
| 80 |
+
|
| 81 |
+
Rules:
|
| 82 |
+
- Output ONLY the SQL query, nothing else
|
| 83 |
+
- No markdown, no code fences, no explanation
|
| 84 |
+
- Use SQLite syntax
|
| 85 |
+
- Do not include semicolons at the end"""
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _clean_sql(raw: str) -> str:
|
| 89 |
+
"""Strip markdown code fences and extra whitespace."""
|
| 90 |
+
raw = raw.strip()
|
| 91 |
+
raw = re.sub(r"^```(?:sql)?\s*", "", raw, flags=re.IGNORECASE)
|
| 92 |
+
raw = re.sub(r"\s*```$", "", raw)
|
| 93 |
+
return raw.strip().rstrip(";")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
async def _call_llm(
|
| 97 |
+
system_prompt: str,
|
| 98 |
+
user_message: str,
|
| 99 |
+
stream: bool = False,
|
| 100 |
+
) -> AsyncIterator[str] | str:
|
| 101 |
+
"""Call the LLM and return the generated text."""
|
| 102 |
+
client = _make_client()
|
| 103 |
+
|
| 104 |
+
if stream:
|
| 105 |
+
async def _gen():
|
| 106 |
+
resp = await client.chat.completions.create(
|
| 107 |
+
model=_MODEL,
|
| 108 |
+
messages=[
|
| 109 |
+
{"role": "system", "content": system_prompt},
|
| 110 |
+
{"role": "user", "content": user_message},
|
| 111 |
+
],
|
| 112 |
+
stream=True,
|
| 113 |
+
temperature=0.1,
|
| 114 |
+
)
|
| 115 |
+
async for chunk in resp:
|
| 116 |
+
delta = chunk.choices[0].delta.content
|
| 117 |
+
if delta:
|
| 118 |
+
yield delta
|
| 119 |
+
return _gen()
|
| 120 |
+
else:
|
| 121 |
+
resp = await client.chat.completions.create(
|
| 122 |
+
model=_MODEL,
|
| 123 |
+
messages=[
|
| 124 |
+
{"role": "system", "content": system_prompt},
|
| 125 |
+
{"role": "user", "content": user_message},
|
| 126 |
+
],
|
| 127 |
+
temperature=0.1,
|
| 128 |
+
)
|
| 129 |
+
return resp.choices[0].message.content or ""
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# βββ Episode State ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 133 |
+
|
| 134 |
+
class _Episode:
|
| 135 |
+
def __init__(self, task_id: str, question_id: str, question: str) -> None:
|
| 136 |
+
self.task_id = task_id
|
| 137 |
+
self.question_id = question_id
|
| 138 |
+
self.question = question
|
| 139 |
+
self.attempt_number = 0
|
| 140 |
+
self.current_sql: Optional[str] = None
|
| 141 |
+
self.error_message: Optional[str] = None
|
| 142 |
+
self.error_class: Optional[str] = None
|
| 143 |
+
self.steps: list[EpisodeStep] = []
|
| 144 |
+
self.step_rewards: list[float] = []
|
| 145 |
+
self.previous_error_class = None
|
| 146 |
+
self.consecutive_same_error = 0
|
| 147 |
+
self.last_action: Optional[RepairAction] = None
|
| 148 |
+
self.current_rl_state: Optional[RLState] = None
|
| 149 |
+
self.current_features: Optional[list[float]] = None
|
| 150 |
+
self.done = False
|
| 151 |
+
self.success = False
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# βββ Main Environment Class βββββββββββββββββββββββββββββββββββββββ
|
| 155 |
+
|
| 156 |
+
class SQLAgentEnv:
|
| 157 |
+
"""
|
| 158 |
+
OpenEnv-compliant environment for SQL generation and repair.
|
| 159 |
+
One active episode at a time.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
MAX_ATTEMPTS = 5
|
| 163 |
+
|
| 164 |
+
def __init__(self) -> None:
|
| 165 |
+
ensure_seeded()
|
| 166 |
+
self._bandit = LinUCB()
|
| 167 |
+
self._episode: Optional[_Episode] = None
|
| 168 |
+
self._schema_info = get_schema_info()
|
| 169 |
+
|
| 170 |
+
def reset(self, task_id: str = "simple_queries") -> Observation:
|
| 171 |
+
"""Start a new episode, picking the first question of the task."""
|
| 172 |
+
if self._episode and self._episode.steps and not self._episode.done:
|
| 173 |
+
self._finalize_episode(success=False)
|
| 174 |
+
|
| 175 |
+
task = get_task(task_id)
|
| 176 |
+
question_obj = task.questions[0]
|
| 177 |
+
|
| 178 |
+
self._episode = _Episode(
|
| 179 |
+
task_id=task_id,
|
| 180 |
+
question_id=question_obj.id,
|
| 181 |
+
question=question_obj.question,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
return self._build_observation()
|
| 185 |
+
|
| 186 |
+
def reset_with_question(
|
| 187 |
+
self, task_id: str, question_id: str
|
| 188 |
+
) -> Observation:
|
| 189 |
+
"""Start a new episode for a specific question."""
|
| 190 |
+
if self._episode and self._episode.steps and not self._episode.done:
|
| 191 |
+
self._finalize_episode(success=False)
|
| 192 |
+
|
| 193 |
+
task = get_task(task_id)
|
| 194 |
+
question_obj = next(
|
| 195 |
+
(q for q in task.questions if q.id == question_id), task.questions[0]
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
self._episode = _Episode(
|
| 199 |
+
task_id=task_id,
|
| 200 |
+
question_id=question_obj.id,
|
| 201 |
+
question=question_obj.question,
|
| 202 |
+
)
|
| 203 |
+
return self._build_observation()
|
| 204 |
+
|
| 205 |
+
async def step(self, action: Action) -> tuple[Observation, RewardInfo]:
|
| 206 |
+
"""
|
| 207 |
+
Execute one step:
|
| 208 |
+
1. Generate/repair SQL via LLM
|
| 209 |
+
2. Execute SQL
|
| 210 |
+
3. Grade and reward
|
| 211 |
+
4. Update bandit
|
| 212 |
+
"""
|
| 213 |
+
if self._episode is None:
|
| 214 |
+
raise RuntimeError("Call reset() before step()")
|
| 215 |
+
if self._episode.done:
|
| 216 |
+
raise RuntimeError("Episode is done. Call reset() to start a new one.")
|
| 217 |
+
|
| 218 |
+
ep = self._episode
|
| 219 |
+
ep.attempt_number += 1
|
| 220 |
+
|
| 221 |
+
# ββ 1. Build prompt ββββββββββββββββββββββββββββββββββββββ
|
| 222 |
+
if action.custom_sql:
|
| 223 |
+
generated_sql = action.custom_sql
|
| 224 |
+
else:
|
| 225 |
+
generated_sql = await self._generate_sql(action, ep)
|
| 226 |
+
|
| 227 |
+
generated_sql = _clean_sql(generated_sql)
|
| 228 |
+
|
| 229 |
+
# ββ 2. Execute SQL βββββββββββββββββββββββββββββββββββββββ
|
| 230 |
+
rows, error = execute_query(generated_sql)
|
| 231 |
+
success = error is None and len(rows) > 0
|
| 232 |
+
|
| 233 |
+
# ββ 3. Grade βββββββββββββββββββββββββββββββββββββββββββββ
|
| 234 |
+
task = get_task(ep.task_id)
|
| 235 |
+
question_obj = next(q for q in task.questions if q.id == ep.question_id)
|
| 236 |
+
|
| 237 |
+
from env.tasks import grade_response
|
| 238 |
+
task_score = grade_response(
|
| 239 |
+
ep.task_id, ep.question_id, generated_sql, rows, error, ep.attempt_number
|
| 240 |
+
)
|
| 241 |
+
success = task_score >= 0.8
|
| 242 |
+
|
| 243 |
+
# ββ 4. RL state + reward βββββββββββββββββββββββββββββββββ
|
| 244 |
+
current_error_class = None
|
| 245 |
+
error_class_name = None
|
| 246 |
+
if error:
|
| 247 |
+
ec = classify_error(error)
|
| 248 |
+
current_error_class = ec
|
| 249 |
+
error_class_name = ERROR_CLASS_NAMES[ec]
|
| 250 |
+
|
| 251 |
+
error_changed = (
|
| 252 |
+
ep.previous_error_class is not None
|
| 253 |
+
and ep.previous_error_class != current_error_class
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
if ep.previous_error_class == current_error_class:
|
| 257 |
+
ep.consecutive_same_error += 1
|
| 258 |
+
else:
|
| 259 |
+
ep.consecutive_same_error = 1
|
| 260 |
+
|
| 261 |
+
rl_state = RLState(
|
| 262 |
+
error_class=current_error_class,
|
| 263 |
+
attempt_number=ep.attempt_number,
|
| 264 |
+
previous_action=ep.last_action,
|
| 265 |
+
error_changed=error_changed,
|
| 266 |
+
consecutive_same_error=ep.consecutive_same_error,
|
| 267 |
+
)
|
| 268 |
+
ep.current_rl_state = rl_state
|
| 269 |
+
ep.current_features = featurize(rl_state)
|
| 270 |
+
|
| 271 |
+
grader_in = GraderInput(
|
| 272 |
+
success=success,
|
| 273 |
+
attempt_number=ep.attempt_number,
|
| 274 |
+
current_error_class=current_error_class,
|
| 275 |
+
previous_error_class=ep.previous_error_class,
|
| 276 |
+
)
|
| 277 |
+
grader_out = compute_reward(grader_in)
|
| 278 |
+
|
| 279 |
+
if ep.current_rl_state and ep.current_features:
|
| 280 |
+
# Determine action index
|
| 281 |
+
if action.repair_action == "generate":
|
| 282 |
+
repair_action_enum = RepairAction.REWRITE_FULL
|
| 283 |
+
else:
|
| 284 |
+
repair_action_enum = REPAIR_ACTION_BY_NAME.get(
|
| 285 |
+
action.repair_action, RepairAction.REWRITE_FULL
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
step_obj = EpisodeStep(
|
| 289 |
+
state=ep.current_rl_state,
|
| 290 |
+
featurized=ep.current_features,
|
| 291 |
+
action=repair_action_enum,
|
| 292 |
+
reward=grader_out.reward,
|
| 293 |
+
error_message=error or "",
|
| 294 |
+
sql=generated_sql,
|
| 295 |
+
success=success,
|
| 296 |
+
)
|
| 297 |
+
ep.steps.append(step_obj)
|
| 298 |
+
|
| 299 |
+
ep.step_rewards.append(grader_out.reward)
|
| 300 |
+
ep.current_sql = generated_sql
|
| 301 |
+
ep.error_message = error
|
| 302 |
+
ep.error_class = error_class_name
|
| 303 |
+
ep.previous_error_class = current_error_class
|
| 304 |
+
|
| 305 |
+
# ββ 5. Done check ββββββββββββββββββββββββββββββββββββββββ
|
| 306 |
+
done = success or ep.attempt_number >= self.MAX_ATTEMPTS
|
| 307 |
+
|
| 308 |
+
if done:
|
| 309 |
+
self._finalize_episode(success=success)
|
| 310 |
+
ep.done = True
|
| 311 |
+
ep.success = success
|
| 312 |
+
|
| 313 |
+
obs = self._build_observation()
|
| 314 |
+
reward_info = RewardInfo(
|
| 315 |
+
value=grader_out.reward,
|
| 316 |
+
success=success,
|
| 317 |
+
done=done,
|
| 318 |
+
info={
|
| 319 |
+
"task_score": task_score,
|
| 320 |
+
"attempt": ep.attempt_number,
|
| 321 |
+
"breakdown": {
|
| 322 |
+
"base": grader_out.breakdown.base,
|
| 323 |
+
"attempt_penalty": grader_out.breakdown.attempt_penalty,
|
| 324 |
+
"severity_bonus": grader_out.breakdown.severity_bonus,
|
| 325 |
+
"change_bonus": grader_out.breakdown.change_bonus,
|
| 326 |
+
},
|
| 327 |
+
"rows": rows[:5] if rows else [],
|
| 328 |
+
"row_count": len(rows),
|
| 329 |
+
"sql": generated_sql,
|
| 330 |
+
},
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
return obs, reward_info
|
| 334 |
+
|
| 335 |
+
async def step_streaming(
|
| 336 |
+
self, action: Action
|
| 337 |
+
) -> AsyncIterator[dict]:
|
| 338 |
+
"""
|
| 339 |
+
Step with SSE-compatible event streaming.
|
| 340 |
+
Yields dicts representing stream events.
|
| 341 |
+
"""
|
| 342 |
+
if self._episode is None:
|
| 343 |
+
raise RuntimeError("Call reset() before step_streaming()")
|
| 344 |
+
|
| 345 |
+
ep = self._episode
|
| 346 |
+
ep.attempt_number += 1
|
| 347 |
+
|
| 348 |
+
yield {"type": "attempt_start", "attempt": ep.attempt_number}
|
| 349 |
+
|
| 350 |
+
# Generate SQL
|
| 351 |
+
if action.custom_sql:
|
| 352 |
+
generated_sql = action.custom_sql
|
| 353 |
+
yield {"type": "sql_complete", "sql": generated_sql}
|
| 354 |
+
else:
|
| 355 |
+
chunks = []
|
| 356 |
+
async for chunk in await self._generate_sql_streaming(action, ep):
|
| 357 |
+
chunks.append(chunk)
|
| 358 |
+
yield {"type": "sql_chunk", "chunk": chunk}
|
| 359 |
+
generated_sql = _clean_sql("".join(chunks))
|
| 360 |
+
yield {"type": "sql_complete", "sql": generated_sql}
|
| 361 |
+
|
| 362 |
+
yield {"type": "executing"}
|
| 363 |
+
|
| 364 |
+
rows, error = execute_query(generated_sql)
|
| 365 |
+
|
| 366 |
+
from env.tasks import grade_response
|
| 367 |
+
task_score = grade_response(
|
| 368 |
+
ep.task_id, ep.question_id, generated_sql, rows, error, ep.attempt_number
|
| 369 |
+
)
|
| 370 |
+
success = task_score >= 0.8
|
| 371 |
+
|
| 372 |
+
# RL processing
|
| 373 |
+
current_error_class = None
|
| 374 |
+
error_class_name = None
|
| 375 |
+
repair_action_enum = RepairAction.REWRITE_FULL
|
| 376 |
+
|
| 377 |
+
if action.repair_action != "generate":
|
| 378 |
+
repair_action_enum = REPAIR_ACTION_BY_NAME.get(
|
| 379 |
+
action.repair_action, RepairAction.REWRITE_FULL
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
if error:
|
| 383 |
+
ec = classify_error(error)
|
| 384 |
+
current_error_class = ec
|
| 385 |
+
error_class_name = ERROR_CLASS_NAMES[ec]
|
| 386 |
+
|
| 387 |
+
error_changed = (
|
| 388 |
+
ep.previous_error_class is not None
|
| 389 |
+
and ep.previous_error_class != current_error_class
|
| 390 |
+
)
|
| 391 |
+
if ep.previous_error_class == current_error_class:
|
| 392 |
+
ep.consecutive_same_error += 1
|
| 393 |
+
else:
|
| 394 |
+
ep.consecutive_same_error = 1
|
| 395 |
+
|
| 396 |
+
rl_state = RLState(
|
| 397 |
+
error_class=current_error_class,
|
| 398 |
+
attempt_number=ep.attempt_number,
|
| 399 |
+
previous_action=ep.last_action,
|
| 400 |
+
error_changed=error_changed,
|
| 401 |
+
consecutive_same_error=ep.consecutive_same_error,
|
| 402 |
+
)
|
| 403 |
+
ep.current_rl_state = rl_state
|
| 404 |
+
ep.current_features = featurize(rl_state)
|
| 405 |
+
|
| 406 |
+
_, scores = self._bandit.select_action(ep.current_features)
|
| 407 |
+
ucb_scores = {
|
| 408 |
+
REPAIR_ACTION_NAMES[RepairAction(i)]: round(scores[i], 4)
|
| 409 |
+
for i in range(len(scores))
|
| 410 |
+
}
|
| 411 |
+
yield {
|
| 412 |
+
"type": "rl_action",
|
| 413 |
+
"action": REPAIR_ACTION_NAMES[repair_action_enum],
|
| 414 |
+
"ucb_scores": ucb_scores,
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
yield {"type": "error", "error": error, "error_class": error_class_name}
|
| 418 |
+
|
| 419 |
+
grader_in = GraderInput(
|
| 420 |
+
success=success,
|
| 421 |
+
attempt_number=ep.attempt_number,
|
| 422 |
+
current_error_class=current_error_class,
|
| 423 |
+
previous_error_class=ep.previous_error_class,
|
| 424 |
+
)
|
| 425 |
+
grader_out = compute_reward(grader_in)
|
| 426 |
+
|
| 427 |
+
if ep.current_rl_state and ep.current_features:
|
| 428 |
+
step_obj = EpisodeStep(
|
| 429 |
+
state=ep.current_rl_state,
|
| 430 |
+
featurized=ep.current_features,
|
| 431 |
+
action=repair_action_enum,
|
| 432 |
+
reward=grader_out.reward,
|
| 433 |
+
error_message=error or "",
|
| 434 |
+
sql=generated_sql,
|
| 435 |
+
success=success,
|
| 436 |
+
)
|
| 437 |
+
ep.steps.append(step_obj)
|
| 438 |
+
self._bandit.update(ep.current_features, repair_action_enum, grader_out.reward)
|
| 439 |
+
|
| 440 |
+
ep.step_rewards.append(grader_out.reward)
|
| 441 |
+
ep.current_sql = generated_sql
|
| 442 |
+
ep.error_message = error
|
| 443 |
+
ep.error_class = error_class_name
|
| 444 |
+
ep.previous_error_class = current_error_class
|
| 445 |
+
|
| 446 |
+
yield {
|
| 447 |
+
"type": "rl_reward",
|
| 448 |
+
"reward": grader_out.reward,
|
| 449 |
+
"breakdown": {
|
| 450 |
+
"base": grader_out.breakdown.base,
|
| 451 |
+
"attempt_penalty": grader_out.breakdown.attempt_penalty,
|
| 452 |
+
"severity_bonus": grader_out.breakdown.severity_bonus,
|
| 453 |
+
"change_bonus": grader_out.breakdown.change_bonus,
|
| 454 |
+
},
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
done = success or ep.attempt_number >= self.MAX_ATTEMPTS
|
| 458 |
+
|
| 459 |
+
if success:
|
| 460 |
+
yield {
|
| 461 |
+
"type": "success",
|
| 462 |
+
"rows": rows,
|
| 463 |
+
"row_count": len(rows),
|
| 464 |
+
"sql": generated_sql,
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
if done:
|
| 468 |
+
total_reward = compute_episode_reward(ep.step_rewards, success)
|
| 469 |
+
self._finalize_episode(success=success)
|
| 470 |
+
ep.done = True
|
| 471 |
+
ep.success = success
|
| 472 |
+
yield {
|
| 473 |
+
"type": "rl_episode_end",
|
| 474 |
+
"total_reward": total_reward,
|
| 475 |
+
"success": success,
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
def state(self) -> dict:
|
| 479 |
+
if self._episode is None:
|
| 480 |
+
return {"active": False}
|
| 481 |
+
ep = self._episode
|
| 482 |
+
return {
|
| 483 |
+
"active": True,
|
| 484 |
+
"task_id": ep.task_id,
|
| 485 |
+
"question_id": ep.question_id,
|
| 486 |
+
"question": ep.question,
|
| 487 |
+
"attempt_number": ep.attempt_number,
|
| 488 |
+
"max_attempts": self.MAX_ATTEMPTS,
|
| 489 |
+
"current_sql": ep.current_sql,
|
| 490 |
+
"error_message": ep.error_message,
|
| 491 |
+
"error_class": ep.error_class,
|
| 492 |
+
"done": ep.done,
|
| 493 |
+
"success": ep.success,
|
| 494 |
+
"step_rewards": ep.step_rewards,
|
| 495 |
+
"total_reward": compute_episode_reward(ep.step_rewards, ep.success),
|
| 496 |
+
}
|
| 497 |
+
|
| 498 |
+
# βββ Private Helpers ββββββββββββββββββββββββββββββββββββββββββ
|
| 499 |
+
|
| 500 |
+
def _build_observation(self) -> Observation:
|
| 501 |
+
if self._episode is None:
|
| 502 |
+
raise RuntimeError("No active episode")
|
| 503 |
+
ep = self._episode
|
| 504 |
+
task = get_task(ep.task_id)
|
| 505 |
+
return Observation(
|
| 506 |
+
question=ep.question,
|
| 507 |
+
schema_info=self._schema_info,
|
| 508 |
+
current_sql=ep.current_sql,
|
| 509 |
+
error_message=ep.error_message,
|
| 510 |
+
error_class=ep.error_class,
|
| 511 |
+
attempt_number=ep.attempt_number,
|
| 512 |
+
max_attempts=self.MAX_ATTEMPTS,
|
| 513 |
+
task_id=ep.task_id,
|
| 514 |
+
task_difficulty=task.difficulty,
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
async def _generate_sql(self, action: Action, ep: _Episode) -> str:
|
| 518 |
+
if action.repair_action == "generate" or ep.current_sql is None:
|
| 519 |
+
system = BASE_SYSTEM_PROMPT
|
| 520 |
+
user = (
|
| 521 |
+
f"Schema:\n{self._schema_info}\n\n"
|
| 522 |
+
f"Question: {ep.question}\n\n"
|
| 523 |
+
"Write a SQL query to answer this question."
|
| 524 |
+
)
|
| 525 |
+
else:
|
| 526 |
+
repair_action_enum = REPAIR_ACTION_BY_NAME.get(
|
| 527 |
+
action.repair_action, RepairAction.REWRITE_FULL
|
| 528 |
+
)
|
| 529 |
+
suffix = get_repair_system_suffix(repair_action_enum)
|
| 530 |
+
offending_token = extract_offending_token(ep.error_message or "")
|
| 531 |
+
ctx = RepairContext(
|
| 532 |
+
schema=self._schema_info,
|
| 533 |
+
question=ep.question,
|
| 534 |
+
failing_sql=ep.current_sql or "",
|
| 535 |
+
error_message=ep.error_message or "",
|
| 536 |
+
offending_token=offending_token,
|
| 537 |
+
)
|
| 538 |
+
system = BASE_SYSTEM_PROMPT + suffix
|
| 539 |
+
user = build_repair_user_message(repair_action_enum, ctx)
|
| 540 |
+
|
| 541 |
+
result = await _call_llm(system, user, stream=False)
|
| 542 |
+
return result # type: ignore[return-value]
|
| 543 |
+
|
| 544 |
+
async def _generate_sql_streaming(
|
| 545 |
+
self, action: Action, ep: _Episode
|
| 546 |
+
) -> AsyncIterator[str]:
|
| 547 |
+
if action.repair_action == "generate" or ep.current_sql is None:
|
| 548 |
+
system = BASE_SYSTEM_PROMPT
|
| 549 |
+
user = (
|
| 550 |
+
f"Schema:\n{self._schema_info}\n\n"
|
| 551 |
+
f"Question: {ep.question}\n\n"
|
| 552 |
+
"Write a SQL query to answer this question."
|
| 553 |
+
)
|
| 554 |
+
else:
|
| 555 |
+
repair_action_enum = REPAIR_ACTION_BY_NAME.get(
|
| 556 |
+
action.repair_action, RepairAction.REWRITE_FULL
|
| 557 |
+
)
|
| 558 |
+
suffix = get_repair_system_suffix(repair_action_enum)
|
| 559 |
+
offending_token = extract_offending_token(ep.error_message or "")
|
| 560 |
+
ctx = RepairContext(
|
| 561 |
+
schema=self._schema_info,
|
| 562 |
+
question=ep.question,
|
| 563 |
+
failing_sql=ep.current_sql or "",
|
| 564 |
+
error_message=ep.error_message or "",
|
| 565 |
+
offending_token=offending_token,
|
| 566 |
+
)
|
| 567 |
+
system = BASE_SYSTEM_PROMPT + suffix
|
| 568 |
+
user = build_repair_user_message(repair_action_enum, ctx)
|
| 569 |
+
|
| 570 |
+
return await _call_llm(system, user, stream=True) # type: ignore[return-value]
|
| 571 |
+
|
| 572 |
+
def _finalize_episode(self, success: bool) -> None:
|
| 573 |
+
ep = self._episode
|
| 574 |
+
if ep is None or not ep.steps:
|
| 575 |
+
return
|
| 576 |
+
try:
|
| 577 |
+
episode_obj, relabeled = record_episode(ep.question, ep.steps, success)
|
| 578 |
+
for exp in relabeled:
|
| 579 |
+
self._bandit.update(exp.state, exp.action, exp.reward)
|
| 580 |
+
self._bandit.decay_alpha()
|
| 581 |
+
except Exception:
|
| 582 |
+
pass
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
# βββ Singleton instance βββββββββββββββββββββββββββββββββββββββββββ
|
| 586 |
+
|
| 587 |
+
_env_instance: Optional[SQLAgentEnv] = None
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def get_env() -> SQLAgentEnv:
|
| 591 |
+
global _env_instance
|
| 592 |
+
if _env_instance is None:
|
| 593 |
+
_env_instance = SQLAgentEnv()
|
| 594 |
+
return _env_instance
|
backend/env/tasks.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Task definitions for the SQL agent benchmark.
|
| 3 |
+
|
| 4 |
+
Three difficulty tiers, each with 5 questions and a grader function.
|
| 5 |
+
|
| 6 |
+
Grader contract: grader(sql, rows, error, attempts) -> float in [0.0, 1.0]
|
| 7 |
+
- rows: list[dict] from the executed SQL (may be empty)
|
| 8 |
+
- error: str | None
|
| 9 |
+
- attempts: int (1-indexed count of attempts taken)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import re
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Callable, Optional
|
| 17 |
+
|
| 18 |
+
from env.database import execute_query
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# βββ Task Definitions βββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class TaskQuestion:
|
| 25 |
+
id: str
|
| 26 |
+
question: str
|
| 27 |
+
expected_columns: list[str] # at least these columns should appear
|
| 28 |
+
min_rows: int # minimum expected rows
|
| 29 |
+
max_rows: Optional[int] = None # None = no upper bound
|
| 30 |
+
hint_tables: list[str] = field(default_factory=list) # tables that must be touched
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class Task:
|
| 35 |
+
id: str
|
| 36 |
+
name: str
|
| 37 |
+
difficulty: str # "easy" | "medium" | "hard"
|
| 38 |
+
description: str
|
| 39 |
+
questions: list[TaskQuestion]
|
| 40 |
+
grader: Callable # grader(question, sql, rows, error, attempts) -> float
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# βββ Grader Helpers βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
+
|
| 45 |
+
def _has_required_columns(rows: list[dict], required: list[str]) -> bool:
|
| 46 |
+
if not rows:
|
| 47 |
+
return False
|
| 48 |
+
row_keys = {k.lower() for k in rows[0].keys()}
|
| 49 |
+
return all(col.lower() in row_keys for col in required)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _row_count_score(rows: list[dict], min_rows: int, max_rows: Optional[int]) -> float:
|
| 53 |
+
n = len(rows)
|
| 54 |
+
if n == 0:
|
| 55 |
+
return 0.0
|
| 56 |
+
if n >= min_rows:
|
| 57 |
+
if max_rows is None or n <= max_rows:
|
| 58 |
+
return 1.0
|
| 59 |
+
# Over the expected maximum β might be a missing WHERE clause
|
| 60 |
+
return 0.5
|
| 61 |
+
# Partial result
|
| 62 |
+
return 0.5 * (n / min_rows)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# βββ Task 1: Simple Queries (Easy) ββββββββββββββββββββββββββββββββ
|
| 66 |
+
|
| 67 |
+
_SIMPLE_QUESTIONS = [
|
| 68 |
+
TaskQuestion(
|
| 69 |
+
id="sq-01",
|
| 70 |
+
question="List all users from the USA.",
|
| 71 |
+
expected_columns=["name", "email", "country"],
|
| 72 |
+
min_rows=10,
|
| 73 |
+
max_rows=25,
|
| 74 |
+
hint_tables=["users"],
|
| 75 |
+
),
|
| 76 |
+
TaskQuestion(
|
| 77 |
+
id="sq-02",
|
| 78 |
+
question="Show all products in the 'Electronics' category with their prices.",
|
| 79 |
+
expected_columns=["name", "price"],
|
| 80 |
+
min_rows=8,
|
| 81 |
+
max_rows=20,
|
| 82 |
+
hint_tables=["products"],
|
| 83 |
+
),
|
| 84 |
+
TaskQuestion(
|
| 85 |
+
id="sq-03",
|
| 86 |
+
question="Find all orders with status 'delivered'.",
|
| 87 |
+
expected_columns=["id", "status"],
|
| 88 |
+
min_rows=30,
|
| 89 |
+
max_rows=50,
|
| 90 |
+
hint_tables=["orders"],
|
| 91 |
+
),
|
| 92 |
+
TaskQuestion(
|
| 93 |
+
id="sq-04",
|
| 94 |
+
question="List all sellers and their countries.",
|
| 95 |
+
expected_columns=["name", "country"],
|
| 96 |
+
min_rows=10,
|
| 97 |
+
max_rows=10,
|
| 98 |
+
hint_tables=["sellers"],
|
| 99 |
+
),
|
| 100 |
+
TaskQuestion(
|
| 101 |
+
id="sq-05",
|
| 102 |
+
question="Show all reviews with a rating of 5 stars.",
|
| 103 |
+
expected_columns=["rating"],
|
| 104 |
+
min_rows=15,
|
| 105 |
+
max_rows=35,
|
| 106 |
+
hint_tables=["reviews"],
|
| 107 |
+
),
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _grade_simple(
|
| 112 |
+
question: TaskQuestion,
|
| 113 |
+
sql: str,
|
| 114 |
+
rows: list[dict],
|
| 115 |
+
error: Optional[str],
|
| 116 |
+
attempts: int,
|
| 117 |
+
) -> float:
|
| 118 |
+
if error:
|
| 119 |
+
return 0.0
|
| 120 |
+
|
| 121 |
+
col_ok = _has_required_columns(rows, question.expected_columns)
|
| 122 |
+
row_score = _row_count_score(rows, question.min_rows, question.max_rows)
|
| 123 |
+
|
| 124 |
+
if col_ok and row_score == 1.0:
|
| 125 |
+
return 1.0
|
| 126 |
+
if col_ok or row_score >= 0.5:
|
| 127 |
+
return 0.5
|
| 128 |
+
return 0.0
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
_TASK_SIMPLE = Task(
|
| 132 |
+
id="simple_queries",
|
| 133 |
+
name="Simple Queries",
|
| 134 |
+
difficulty="easy",
|
| 135 |
+
description="Single-table SELECT queries with basic filters.",
|
| 136 |
+
questions=_SIMPLE_QUESTIONS,
|
| 137 |
+
grader=_grade_simple,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# βββ Task 2: Join Queries (Medium) ββββββββββββββββββββββββββββββββ
|
| 142 |
+
|
| 143 |
+
_JOIN_QUESTIONS = [
|
| 144 |
+
TaskQuestion(
|
| 145 |
+
id="jq-01",
|
| 146 |
+
question="Show the total number of orders per user, including the user's name.",
|
| 147 |
+
expected_columns=["name"],
|
| 148 |
+
min_rows=10,
|
| 149 |
+
hint_tables=["users", "orders"],
|
| 150 |
+
),
|
| 151 |
+
TaskQuestion(
|
| 152 |
+
id="jq-02",
|
| 153 |
+
question="List products along with the name of their seller.",
|
| 154 |
+
expected_columns=["name", "name"], # product name + seller name both called 'name'
|
| 155 |
+
min_rows=20,
|
| 156 |
+
hint_tables=["products", "sellers"],
|
| 157 |
+
),
|
| 158 |
+
TaskQuestion(
|
| 159 |
+
id="jq-03",
|
| 160 |
+
question="Find the average rating for each product category.",
|
| 161 |
+
expected_columns=["category"],
|
| 162 |
+
min_rows=5,
|
| 163 |
+
max_rows=10,
|
| 164 |
+
hint_tables=["products", "reviews"],
|
| 165 |
+
),
|
| 166 |
+
TaskQuestion(
|
| 167 |
+
id="jq-04",
|
| 168 |
+
question="Show the total revenue (sum of total_price) per seller.",
|
| 169 |
+
expected_columns=["name"],
|
| 170 |
+
min_rows=5,
|
| 171 |
+
hint_tables=["sellers", "products", "orders"],
|
| 172 |
+
),
|
| 173 |
+
TaskQuestion(
|
| 174 |
+
id="jq-05",
|
| 175 |
+
question="List the top 5 most reviewed products with their review counts.",
|
| 176 |
+
expected_columns=["name"],
|
| 177 |
+
min_rows=5,
|
| 178 |
+
max_rows=5,
|
| 179 |
+
hint_tables=["products", "reviews"],
|
| 180 |
+
),
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _grade_join(
|
| 185 |
+
question: TaskQuestion,
|
| 186 |
+
sql: str,
|
| 187 |
+
rows: list[dict],
|
| 188 |
+
error: Optional[str],
|
| 189 |
+
attempts: int,
|
| 190 |
+
) -> float:
|
| 191 |
+
if error:
|
| 192 |
+
return 0.0
|
| 193 |
+
|
| 194 |
+
col_ok = _has_required_columns(rows, [question.expected_columns[0]])
|
| 195 |
+
row_score = _row_count_score(rows, question.min_rows, question.max_rows)
|
| 196 |
+
|
| 197 |
+
base = 0.0
|
| 198 |
+
if col_ok and row_score == 1.0:
|
| 199 |
+
base = 1.0
|
| 200 |
+
elif col_ok or row_score >= 0.5:
|
| 201 |
+
base = 0.5
|
| 202 |
+
|
| 203 |
+
# Penalize extra attempts
|
| 204 |
+
attempt_penalty = max(0.0, 0.1 * (attempts - 1))
|
| 205 |
+
return max(0.0, base - attempt_penalty)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
_TASK_JOIN = Task(
|
| 209 |
+
id="join_queries",
|
| 210 |
+
name="Join Queries",
|
| 211 |
+
difficulty="medium",
|
| 212 |
+
description="Multi-table JOINs with GROUP BY and aggregation.",
|
| 213 |
+
questions=_JOIN_QUESTIONS,
|
| 214 |
+
grader=_grade_join,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# βββ Task 3: Complex Queries (Hard) βββββββββββββββββββββββββββββββ
|
| 219 |
+
|
| 220 |
+
_COMPLEX_QUESTIONS = [
|
| 221 |
+
TaskQuestion(
|
| 222 |
+
id="cq-01",
|
| 223 |
+
question=(
|
| 224 |
+
"Find users who have placed more than 1 order, showing their name "
|
| 225 |
+
"and total number of orders, ordered by order count descending."
|
| 226 |
+
),
|
| 227 |
+
expected_columns=["name"],
|
| 228 |
+
min_rows=1,
|
| 229 |
+
hint_tables=["users", "orders"],
|
| 230 |
+
),
|
| 231 |
+
TaskQuestion(
|
| 232 |
+
id="cq-02",
|
| 233 |
+
question=(
|
| 234 |
+
"For each product category, show the category name, number of products, "
|
| 235 |
+
"average price, and total stock. Use a CTE."
|
| 236 |
+
),
|
| 237 |
+
expected_columns=["category"],
|
| 238 |
+
min_rows=5,
|
| 239 |
+
max_rows=10,
|
| 240 |
+
hint_tables=["products"],
|
| 241 |
+
),
|
| 242 |
+
TaskQuestion(
|
| 243 |
+
id="cq-03",
|
| 244 |
+
question=(
|
| 245 |
+
"Show each seller's name, their total sales revenue, and rank them "
|
| 246 |
+
"by revenue using a window function (RANK() or ROW_NUMBER())."
|
| 247 |
+
),
|
| 248 |
+
expected_columns=["name"],
|
| 249 |
+
min_rows=5,
|
| 250 |
+
hint_tables=["sellers", "products", "orders"],
|
| 251 |
+
),
|
| 252 |
+
TaskQuestion(
|
| 253 |
+
id="cq-04",
|
| 254 |
+
question=(
|
| 255 |
+
"Find the top-rated product in each category (highest average review rating). "
|
| 256 |
+
"Show category, product name, and average rating."
|
| 257 |
+
),
|
| 258 |
+
expected_columns=["category", "name"],
|
| 259 |
+
min_rows=5,
|
| 260 |
+
max_rows=10,
|
| 261 |
+
hint_tables=["products", "reviews"],
|
| 262 |
+
),
|
| 263 |
+
TaskQuestion(
|
| 264 |
+
id="cq-05",
|
| 265 |
+
question=(
|
| 266 |
+
"Calculate the month-over-month order count for 2024, showing year, "
|
| 267 |
+
"month, order_count, and a running total."
|
| 268 |
+
),
|
| 269 |
+
expected_columns=["month"],
|
| 270 |
+
min_rows=6,
|
| 271 |
+
max_rows=12,
|
| 272 |
+
hint_tables=["orders"],
|
| 273 |
+
),
|
| 274 |
+
]
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _grade_complex(
|
| 278 |
+
question: TaskQuestion,
|
| 279 |
+
sql: str,
|
| 280 |
+
rows: list[dict],
|
| 281 |
+
error: Optional[str],
|
| 282 |
+
attempts: int,
|
| 283 |
+
) -> float:
|
| 284 |
+
if error:
|
| 285 |
+
return 0.0
|
| 286 |
+
|
| 287 |
+
col_ok = _has_required_columns(rows, question.expected_columns)
|
| 288 |
+
row_score = _row_count_score(rows, question.min_rows, question.max_rows)
|
| 289 |
+
|
| 290 |
+
if not col_ok or row_score == 0.0:
|
| 291 |
+
return 0.0
|
| 292 |
+
|
| 293 |
+
# Hard task base max is 0.8 unless first-attempt bonus
|
| 294 |
+
if row_score == 1.0 and col_ok:
|
| 295 |
+
base = 0.8 + (0.2 if attempts == 1 else 0.0)
|
| 296 |
+
else:
|
| 297 |
+
base = 0.4 # partial
|
| 298 |
+
|
| 299 |
+
# Strict attempt penalty for hard queries
|
| 300 |
+
attempt_penalty = 0.1 * (attempts - 1)
|
| 301 |
+
return max(0.0, base - attempt_penalty)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
_TASK_COMPLEX = Task(
|
| 305 |
+
id="complex_queries",
|
| 306 |
+
name="Complex Queries",
|
| 307 |
+
difficulty="hard",
|
| 308 |
+
description="CTEs, window functions, and nested aggregations.",
|
| 309 |
+
questions=_COMPLEX_QUESTIONS,
|
| 310 |
+
grader=_grade_complex,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# βββ Registry βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 315 |
+
|
| 316 |
+
TASKS: dict[str, Task] = {
|
| 317 |
+
"simple_queries": _TASK_SIMPLE,
|
| 318 |
+
"join_queries": _TASK_JOIN,
|
| 319 |
+
"complex_queries": _TASK_COMPLEX,
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def get_task(task_id: str) -> Task:
|
| 324 |
+
if task_id not in TASKS:
|
| 325 |
+
raise ValueError(f"Unknown task_id: {task_id!r}. Valid: {list(TASKS)}")
|
| 326 |
+
return TASKS[task_id]
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def get_all_tasks() -> list[Task]:
|
| 330 |
+
return list(TASKS.values())
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def grade_response(
|
| 334 |
+
task_id: str,
|
| 335 |
+
question_id: str,
|
| 336 |
+
sql: str,
|
| 337 |
+
rows: list[dict],
|
| 338 |
+
error: Optional[str],
|
| 339 |
+
attempts: int,
|
| 340 |
+
) -> float:
|
| 341 |
+
task = get_task(task_id)
|
| 342 |
+
question = next((q for q in task.questions if q.id == question_id), None)
|
| 343 |
+
if question is None:
|
| 344 |
+
raise ValueError(f"Unknown question_id {question_id!r} in task {task_id!r}")
|
| 345 |
+
return task.grader(question, sql, rows, error, attempts)
|
backend/gepa/__init__.py
ADDED
|
File without changes
|
backend/gepa/optimizer.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GEPA (Goal-directed Evolutionary Prompt Adaptation) optimizer.
|
| 3 |
+
|
| 4 |
+
Ported from gepa.ts. Key steps:
|
| 5 |
+
1. Reflection: LLM analyzes failure history, outputs diagnosis
|
| 6 |
+
2. Mutation: LLM rewrites system prompt based on diagnosis
|
| 7 |
+
3. Scoring: Run 3 golden queries with new prompt, compute score
|
| 8 |
+
4. Pareto front: Keep top 3 prompts by (score, diversity)
|
| 9 |
+
|
| 10 |
+
State is persisted to data/gepa_prompt.json.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import time
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
from openai import AsyncOpenAI
|
| 22 |
+
from pydantic import BaseModel
|
| 23 |
+
|
| 24 |
+
_DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
|
| 25 |
+
GEPA_PATH = _DATA_DIR / "gepa_prompt.json"
|
| 26 |
+
|
| 27 |
+
_MODEL = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 28 |
+
|
| 29 |
+
SEED_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
|
| 30 |
+
|
| 31 |
+
Rules:
|
| 32 |
+
- Output ONLY the SQL query, nothing else
|
| 33 |
+
- No markdown, no code fences, no explanation
|
| 34 |
+
- Use SQLite syntax"""
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# βββ Models ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
+
|
| 39 |
+
class QueryResult(BaseModel):
|
| 40 |
+
question: str
|
| 41 |
+
final_sql: str
|
| 42 |
+
attempts: int
|
| 43 |
+
success: bool
|
| 44 |
+
errors: list[str]
|
| 45 |
+
timestamp: float
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class Candidate(BaseModel):
|
| 49 |
+
system_prompt: str
|
| 50 |
+
score: float
|
| 51 |
+
avg_attempts: float
|
| 52 |
+
success_rate: float
|
| 53 |
+
generation: int
|
| 54 |
+
feedback: list[str]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# βββ LLM Helper ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 58 |
+
|
| 59 |
+
def _make_client() -> AsyncOpenAI:
|
| 60 |
+
return AsyncOpenAI(
|
| 61 |
+
api_key=os.environ.get("HF_TOKEN", ""),
|
| 62 |
+
base_url=os.environ.get("API_BASE_URL", "https://api.openai.com/v1"),
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
async def _complete(system: str, user: str) -> str:
|
| 67 |
+
client = _make_client()
|
| 68 |
+
resp = await client.chat.completions.create(
|
| 69 |
+
model=_MODEL,
|
| 70 |
+
messages=[
|
| 71 |
+
{"role": "system", "content": system},
|
| 72 |
+
{"role": "user", "content": user},
|
| 73 |
+
],
|
| 74 |
+
temperature=0.7,
|
| 75 |
+
)
|
| 76 |
+
return resp.choices[0].message.content or ""
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# βββ Golden Queries for Scoring ββββββββββββββββββββββββββββββββββ
|
| 80 |
+
|
| 81 |
+
_GOLDEN_QUERIES = [
|
| 82 |
+
{
|
| 83 |
+
"id": "gq-01",
|
| 84 |
+
"question": "List all users from the USA.",
|
| 85 |
+
"expected_min_rows": 10,
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"id": "gq-02",
|
| 89 |
+
"question": "Show all products in the 'Electronics' category.",
|
| 90 |
+
"expected_min_rows": 8,
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"id": "gq-03",
|
| 94 |
+
"question": "Find the total number of orders per user.",
|
| 95 |
+
"expected_min_rows": 10,
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"id": "gq-04",
|
| 99 |
+
"question": "Show the average rating for each product category.",
|
| 100 |
+
"expected_min_rows": 5,
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"id": "gq-05",
|
| 104 |
+
"question": "List products along with their seller name.",
|
| 105 |
+
"expected_min_rows": 20,
|
| 106 |
+
},
|
| 107 |
+
]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# βββ Optimizer Class ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 111 |
+
|
| 112 |
+
class GEPAOptimizer:
|
| 113 |
+
def __init__(self) -> None:
|
| 114 |
+
self._history: list[QueryResult] = []
|
| 115 |
+
self._pareto_front: list[Candidate] = [
|
| 116 |
+
Candidate(
|
| 117 |
+
system_prompt=SEED_SYSTEM_PROMPT,
|
| 118 |
+
score=0.5,
|
| 119 |
+
avg_attempts=3.0,
|
| 120 |
+
success_rate=0.5,
|
| 121 |
+
generation=0,
|
| 122 |
+
feedback=[],
|
| 123 |
+
)
|
| 124 |
+
]
|
| 125 |
+
self._load()
|
| 126 |
+
|
| 127 |
+
# βββ Public Interface βββββββββββββββββββββββββββββββββββββββββ
|
| 128 |
+
|
| 129 |
+
def record_result(self, result: QueryResult) -> None:
|
| 130 |
+
self._history.append(result)
|
| 131 |
+
self._save()
|
| 132 |
+
|
| 133 |
+
def get_current_prompt(self) -> str:
|
| 134 |
+
if not self._pareto_front:
|
| 135 |
+
return SEED_SYSTEM_PROMPT
|
| 136 |
+
return max(self._pareto_front, key=lambda c: c.score).system_prompt
|
| 137 |
+
|
| 138 |
+
def get_history(self) -> list[QueryResult]:
|
| 139 |
+
return list(self._history)
|
| 140 |
+
|
| 141 |
+
def get_pareto_front(self) -> list[Candidate]:
|
| 142 |
+
return list(self._pareto_front)
|
| 143 |
+
|
| 144 |
+
def set_current_prompt(self, prompt: str) -> None:
|
| 145 |
+
if self._pareto_front:
|
| 146 |
+
best = max(self._pareto_front, key=lambda c: c.score)
|
| 147 |
+
best.system_prompt = prompt
|
| 148 |
+
else:
|
| 149 |
+
self._pareto_front.append(
|
| 150 |
+
Candidate(
|
| 151 |
+
system_prompt=prompt,
|
| 152 |
+
score=0.5,
|
| 153 |
+
avg_attempts=3.0,
|
| 154 |
+
success_rate=0.5,
|
| 155 |
+
generation=0,
|
| 156 |
+
feedback=[],
|
| 157 |
+
)
|
| 158 |
+
)
|
| 159 |
+
self._save()
|
| 160 |
+
|
| 161 |
+
def should_optimize(self) -> bool:
|
| 162 |
+
return len(self._history) > 0 and len(self._history) % 4 == 0
|
| 163 |
+
|
| 164 |
+
def reset(self) -> None:
|
| 165 |
+
self._history.clear()
|
| 166 |
+
self._pareto_front.clear()
|
| 167 |
+
self._pareto_front.append(
|
| 168 |
+
Candidate(
|
| 169 |
+
system_prompt=SEED_SYSTEM_PROMPT,
|
| 170 |
+
score=0.5,
|
| 171 |
+
avg_attempts=3.0,
|
| 172 |
+
success_rate=0.5,
|
| 173 |
+
generation=0,
|
| 174 |
+
feedback=[],
|
| 175 |
+
)
|
| 176 |
+
)
|
| 177 |
+
self._save()
|
| 178 |
+
|
| 179 |
+
async def run_optimization_cycle(
|
| 180 |
+
self,
|
| 181 |
+
user_feedback_context: Optional[str] = None,
|
| 182 |
+
dialect: str = "SQLite",
|
| 183 |
+
) -> Optional[dict]:
|
| 184 |
+
"""
|
| 185 |
+
Run one GEPA cycle: reflect β mutate β score β update Pareto front.
|
| 186 |
+
Returns {new_prompt, reflection} or None if not enough data.
|
| 187 |
+
"""
|
| 188 |
+
if len(self._history) < 2:
|
| 189 |
+
return None
|
| 190 |
+
|
| 191 |
+
recent_failures = [
|
| 192 |
+
h for h in self._history if h.attempts > 1 or not h.success
|
| 193 |
+
][-8:]
|
| 194 |
+
if len(recent_failures) < 2:
|
| 195 |
+
return None
|
| 196 |
+
|
| 197 |
+
current_best = self.get_current_prompt()
|
| 198 |
+
|
| 199 |
+
# ββ Step 1: Reflect ββββββββββββββββββββββββββββββββββββββ
|
| 200 |
+
failure_summary = "\n\n---\n\n".join(
|
| 201 |
+
f'Query {i+1}: "{f.question}"\n'
|
| 202 |
+
f"Attempts: {f.attempts}\n"
|
| 203 |
+
f"Errors:\n" + "\n".join(f" - {e}" for e in f.errors) + "\n"
|
| 204 |
+
f"Final SQL: {f.final_sql}"
|
| 205 |
+
for i, f in enumerate(recent_failures)
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
user_ctx_block = (
|
| 209 |
+
f"\n\nUser conversation:\n{user_feedback_context}"
|
| 210 |
+
if user_feedback_context
|
| 211 |
+
else ""
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
reflection = await _complete(
|
| 215 |
+
f"You are an expert SQL prompt engineer analyzing why an LLM SQL agent is failing.\n"
|
| 216 |
+
f"The target database is {dialect} β all rules must use {dialect} syntax.\n"
|
| 217 |
+
"Your job: identify specific, recurring patterns in these failures and state EXACTLY "
|
| 218 |
+
"what rules or knowledge the system prompt is missing.\n"
|
| 219 |
+
"Be very specific β name the exact functions, syntax patterns, or schema reasoning gaps.\n"
|
| 220 |
+
"Output a concise diagnosis (3-5 bullet points max).",
|
| 221 |
+
f"Current system prompt:\n{current_best}\n\n"
|
| 222 |
+
f"Recent failures:\n{failure_summary}{user_ctx_block}",
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# ββ Step 2: Mutate βββββββββββββββββββββββββββββββββββββββ
|
| 226 |
+
current_generation = max(c.generation for c in self._pareto_front) if self._pareto_front else 0
|
| 227 |
+
|
| 228 |
+
new_prompt = await _complete(
|
| 229 |
+
f"You are an expert prompt engineer. Improve a system prompt for a {dialect} SQL generation agent.\n"
|
| 230 |
+
"Rules for the new prompt:\n"
|
| 231 |
+
"- Keep it concise and actionable\n"
|
| 232 |
+
f"- The target database is {dialect} β use ONLY {dialect} syntax and functions\n"
|
| 233 |
+
"- Add specific rules that address the diagnosed failure patterns\n"
|
| 234 |
+
"- Do NOT add generic fluff β every rule must be earned by a real failure\n"
|
| 235 |
+
"- Output ONLY the improved system prompt text, nothing else",
|
| 236 |
+
f"Current system prompt:\n{current_best}\n\n"
|
| 237 |
+
f"Diagnosed failure patterns:\n{reflection}\n\n"
|
| 238 |
+
"Write the improved system prompt:",
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# ββ Step 3: Score ββββββββββββββββββββββββββββββββββββββββ
|
| 242 |
+
benchmark_score = await self._score_prompt(new_prompt)
|
| 243 |
+
|
| 244 |
+
current_avg_attempts = (
|
| 245 |
+
sum(h.attempts for h in self._history) / len(self._history)
|
| 246 |
+
if self._history
|
| 247 |
+
else 3.0
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
new_candidate = Candidate(
|
| 251 |
+
system_prompt=new_prompt,
|
| 252 |
+
score=benchmark_score,
|
| 253 |
+
avg_attempts=max(current_avg_attempts - 0.5, 1.0),
|
| 254 |
+
success_rate=benchmark_score,
|
| 255 |
+
generation=current_generation + 1,
|
| 256 |
+
feedback=[reflection],
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# ββ Step 4: Update Pareto front ββββββββββββββββββββββββββ
|
| 260 |
+
self._pareto_front.append(new_candidate)
|
| 261 |
+
self._pareto_front.sort(key=lambda c: c.score, reverse=True)
|
| 262 |
+
if len(self._pareto_front) > 3:
|
| 263 |
+
self._pareto_front = self._pareto_front[:3]
|
| 264 |
+
|
| 265 |
+
self._save()
|
| 266 |
+
return {"new_prompt": new_prompt, "reflection": reflection}
|
| 267 |
+
|
| 268 |
+
async def _score_prompt(self, prompt: str) -> float:
|
| 269 |
+
"""
|
| 270 |
+
Score a prompt by running 3 golden queries and measuring success rate.
|
| 271 |
+
"""
|
| 272 |
+
from env.database import execute_query, get_schema_info
|
| 273 |
+
import re
|
| 274 |
+
|
| 275 |
+
schema = get_schema_info()
|
| 276 |
+
client = _make_client()
|
| 277 |
+
|
| 278 |
+
scores = []
|
| 279 |
+
for gq in _GOLDEN_QUERIES[:3]:
|
| 280 |
+
try:
|
| 281 |
+
resp = await client.chat.completions.create(
|
| 282 |
+
model=_MODEL,
|
| 283 |
+
messages=[
|
| 284 |
+
{"role": "system", "content": prompt},
|
| 285 |
+
{
|
| 286 |
+
"role": "user",
|
| 287 |
+
"content": (
|
| 288 |
+
f"Schema:\n{schema}\n\n"
|
| 289 |
+
f"Question: {gq['question']}\n\n"
|
| 290 |
+
"Write a SQL query."
|
| 291 |
+
),
|
| 292 |
+
},
|
| 293 |
+
],
|
| 294 |
+
temperature=0.1,
|
| 295 |
+
)
|
| 296 |
+
sql = resp.choices[0].message.content or ""
|
| 297 |
+
sql = re.sub(r"^```(?:sql)?\s*", "", sql.strip(), flags=re.IGNORECASE)
|
| 298 |
+
sql = re.sub(r"\s*```$", "", sql).strip().rstrip(";")
|
| 299 |
+
|
| 300 |
+
rows, error = execute_query(sql)
|
| 301 |
+
if error is None and len(rows) >= gq["expected_min_rows"]:
|
| 302 |
+
scores.append(1.0)
|
| 303 |
+
elif error is None and rows:
|
| 304 |
+
scores.append(0.5)
|
| 305 |
+
else:
|
| 306 |
+
scores.append(0.0)
|
| 307 |
+
except Exception:
|
| 308 |
+
scores.append(0.0)
|
| 309 |
+
|
| 310 |
+
return sum(scores) / len(scores) if scores else 0.3
|
| 311 |
+
|
| 312 |
+
# βββ Persistence βββββββββββββββββββββββββββββββββββββββββββββ
|
| 313 |
+
|
| 314 |
+
def _save(self) -> None:
|
| 315 |
+
try:
|
| 316 |
+
GEPA_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 317 |
+
data = {
|
| 318 |
+
"history": [r.model_dump() for r in self._history[-100:]],
|
| 319 |
+
"pareto_front": [c.model_dump() for c in self._pareto_front],
|
| 320 |
+
}
|
| 321 |
+
GEPA_PATH.write_text(json.dumps(data, default=str))
|
| 322 |
+
except Exception:
|
| 323 |
+
pass
|
| 324 |
+
|
| 325 |
+
def _load(self) -> None:
|
| 326 |
+
try:
|
| 327 |
+
if not GEPA_PATH.exists():
|
| 328 |
+
return
|
| 329 |
+
data = json.loads(GEPA_PATH.read_text())
|
| 330 |
+
self._history = [QueryResult(**r) for r in data.get("history", [])]
|
| 331 |
+
loaded_front = [Candidate(**c) for c in data.get("pareto_front", [])]
|
| 332 |
+
if loaded_front:
|
| 333 |
+
self._pareto_front = loaded_front
|
| 334 |
+
except Exception:
|
| 335 |
+
pass
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# βββ Singleton ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 339 |
+
|
| 340 |
+
_gepa_instance: Optional[GEPAOptimizer] = None
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def get_gepa() -> GEPAOptimizer:
|
| 344 |
+
global _gepa_instance
|
| 345 |
+
if _gepa_instance is None:
|
| 346 |
+
_gepa_instance = GEPAOptimizer()
|
| 347 |
+
return _gepa_instance
|
backend/main.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQL Agent OpenEnv β FastAPI entry point.
|
| 3 |
+
|
| 4 |
+
Start with:
|
| 5 |
+
uvicorn main:app --reload --port 8000
|
| 6 |
+
|
| 7 |
+
Environment variables:
|
| 8 |
+
API_BASE_URL β OpenAI-compatible base URL
|
| 9 |
+
MODEL_NAME β model name
|
| 10 |
+
HF_TOKEN β API key / bearer token
|
| 11 |
+
DATA_DIR β override data directory (default: ./data)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
from fastapi import FastAPI
|
| 20 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 21 |
+
from fastapi.staticfiles import StaticFiles
|
| 22 |
+
|
| 23 |
+
from api.demo import router as demo_router
|
| 24 |
+
from api.openenv import router as openenv_router, ResetRequest, StepRequest, env_reset, env_step, env_state
|
| 25 |
+
from env.database import ensure_seeded
|
| 26 |
+
|
| 27 |
+
app = FastAPI(
|
| 28 |
+
title="SQL Agent OpenEnv",
|
| 29 |
+
description=(
|
| 30 |
+
"A SQL generation environment powered by a LinUCB contextual bandit "
|
| 31 |
+
"and GEPA prompt evolution, built for the Meta + Hugging Face OpenEnv hackathon."
|
| 32 |
+
),
|
| 33 |
+
version="1.0.0",
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# βββ CORS ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
|
| 38 |
+
app.add_middleware(
|
| 39 |
+
CORSMiddleware,
|
| 40 |
+
allow_origins=["*"],
|
| 41 |
+
allow_credentials=True,
|
| 42 |
+
allow_methods=["*"],
|
| 43 |
+
allow_headers=["*"],
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# βββ Routers βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
+
|
| 48 |
+
app.include_router(demo_router, prefix="/api", tags=["demo"])
|
| 49 |
+
app.include_router(openenv_router, prefix="/env", tags=["openenv"])
|
| 50 |
+
|
| 51 |
+
# βββ Top-level OpenEnv aliases (required by openenv validate + pre-validation) β
|
| 52 |
+
# The validator pings POST <url>/reset β these mirror /env/* without the prefix.
|
| 53 |
+
|
| 54 |
+
@app.post("/reset", tags=["openenv"])
|
| 55 |
+
async def root_reset(req: ResetRequest = None):
|
| 56 |
+
return await env_reset(req or ResetRequest())
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@app.post("/step", tags=["openenv"])
|
| 60 |
+
async def root_step(req: StepRequest = None):
|
| 61 |
+
return await env_step(req or StepRequest())
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@app.get("/state", tags=["openenv"])
|
| 65 |
+
async def root_state():
|
| 66 |
+
return await env_state()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# βββ Health check ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
+
|
| 71 |
+
@app.get("/health", tags=["system"])
|
| 72 |
+
async def health():
|
| 73 |
+
return {"status": "ok", "service": "sql-agent-openenv"}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# βββ Startup βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
+
|
| 78 |
+
@app.on_event("startup")
|
| 79 |
+
async def startup_event():
|
| 80 |
+
"""Seed the database on first startup."""
|
| 81 |
+
try:
|
| 82 |
+
ensure_seeded()
|
| 83 |
+
except Exception as e:
|
| 84 |
+
print(f"Warning: database seed failed: {e}")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# βββ Static files (frontend) β mount last βββββββββββββββββββββββββ
|
| 88 |
+
|
| 89 |
+
_frontend_dist = Path(__file__).parent.parent / "frontend" / "dist"
|
| 90 |
+
if _frontend_dist.exists():
|
| 91 |
+
app.mount(
|
| 92 |
+
"/",
|
| 93 |
+
StaticFiles(directory=str(_frontend_dist), html=True),
|
| 94 |
+
name="frontend",
|
| 95 |
+
)
|
| 96 |
+
else:
|
| 97 |
+
@app.get("/", tags=["system"])
|
| 98 |
+
async def root():
|
| 99 |
+
return {
|
| 100 |
+
"message": "SQL Agent OpenEnv API",
|
| 101 |
+
"docs": "/docs",
|
| 102 |
+
"health": "/health",
|
| 103 |
+
"env_info": "/env/info",
|
| 104 |
+
}
|
backend/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.115.0
|
| 2 |
+
uvicorn[standard]>=0.30.0
|
| 3 |
+
openai>=1.40.0
|
| 4 |
+
pydantic>=2.8.0
|
| 5 |
+
numpy>=1.26.0
|
| 6 |
+
aiofiles>=24.0.0
|
| 7 |
+
python-multipart>=0.0.9
|
| 8 |
+
sse-starlette>=2.1.0
|
| 9 |
+
aiosqlite>=0.20.0
|
backend/rl/__init__.py
ADDED
|
File without changes
|
backend/rl/environment.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQLDebugEnvironment β Gym-like RL environment for the SQL debug loop.
|
| 3 |
+
|
| 4 |
+
Lifecycle:
|
| 5 |
+
1. env.reset(question) β start new episode
|
| 6 |
+
2. env.observe_error(error, sql) β classify error, build state
|
| 7 |
+
3. env.select_action() β bandit picks repair strategy
|
| 8 |
+
4. env.get_repair_prompt(...) β get specialized prompt for chosen action
|
| 9 |
+
5. env.record_step(success) β record outcome, compute reward
|
| 10 |
+
6. Repeat 2-5 until success or max attempts
|
| 11 |
+
7. env.end_episode(success) β finalize, HER relabeling, bandit update
|
| 12 |
+
|
| 13 |
+
This module is a stateful singleton β one active episode at a time.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import time
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
from rl.types import (
|
| 22 |
+
RLState,
|
| 23 |
+
RepairAction,
|
| 24 |
+
ErrorClass,
|
| 25 |
+
EpisodeStep,
|
| 26 |
+
RLMetrics,
|
| 27 |
+
featurize,
|
| 28 |
+
REPAIR_ACTION_NAMES,
|
| 29 |
+
ERROR_CLASS_NAMES,
|
| 30 |
+
)
|
| 31 |
+
from rl.error_classifier import classify_error, extract_offending_token
|
| 32 |
+
from rl.grader import GraderInput, compute_reward
|
| 33 |
+
from rl.linucb import LinUCB
|
| 34 |
+
from rl.experience import record_episode, get_metrics, reset_experience
|
| 35 |
+
from rl.repair_strategies import (
|
| 36 |
+
RepairContext,
|
| 37 |
+
get_repair_system_suffix,
|
| 38 |
+
build_repair_user_message,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# βββ Singleton State βββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
|
| 43 |
+
_bandit: Optional[LinUCB] = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class _EpisodeContext:
|
| 47 |
+
def __init__(self, question: str) -> None:
|
| 48 |
+
self.question = question
|
| 49 |
+
self.steps: list[EpisodeStep] = []
|
| 50 |
+
self.previous_error_class: Optional[ErrorClass] = None
|
| 51 |
+
self.consecutive_same_error: int = 0
|
| 52 |
+
self.last_action: Optional[RepairAction] = None
|
| 53 |
+
self.current_state: Optional[RLState] = None
|
| 54 |
+
self.current_features: Optional[list[float]] = None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
_current_episode: Optional[_EpisodeContext] = None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _get_bandit() -> LinUCB:
|
| 61 |
+
global _bandit
|
| 62 |
+
if _bandit is None:
|
| 63 |
+
_bandit = LinUCB()
|
| 64 |
+
return _bandit
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# βββ Environment Interface ββββββββββββββββββββββββββββββββββββββββ
|
| 68 |
+
|
| 69 |
+
def reset(question: str) -> None:
|
| 70 |
+
"""Start a new episode. If a previous episode was active, end it as failure."""
|
| 71 |
+
global _current_episode
|
| 72 |
+
if _current_episode and _current_episode.steps:
|
| 73 |
+
end_episode(False)
|
| 74 |
+
_current_episode = _EpisodeContext(question)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def observe_error(
|
| 78 |
+
error_message: str,
|
| 79 |
+
failing_sql: str,
|
| 80 |
+
attempt_number: int,
|
| 81 |
+
) -> dict:
|
| 82 |
+
"""
|
| 83 |
+
Classify the SQL execution error and build the RL state.
|
| 84 |
+
Returns a dict with keys: error_class, error_class_name, state.
|
| 85 |
+
"""
|
| 86 |
+
if _current_episode is None:
|
| 87 |
+
raise RuntimeError("Call reset() before observe_error()")
|
| 88 |
+
|
| 89 |
+
error_class = classify_error(error_message)
|
| 90 |
+
error_changed = (
|
| 91 |
+
_current_episode.previous_error_class is not None
|
| 92 |
+
and _current_episode.previous_error_class != error_class
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
if _current_episode.previous_error_class == error_class:
|
| 96 |
+
_current_episode.consecutive_same_error += 1
|
| 97 |
+
else:
|
| 98 |
+
_current_episode.consecutive_same_error = 1
|
| 99 |
+
|
| 100 |
+
state = RLState(
|
| 101 |
+
error_class=error_class,
|
| 102 |
+
attempt_number=attempt_number,
|
| 103 |
+
previous_action=_current_episode.last_action,
|
| 104 |
+
error_changed=error_changed,
|
| 105 |
+
consecutive_same_error=_current_episode.consecutive_same_error,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
_current_episode.current_state = state
|
| 109 |
+
_current_episode.current_features = featurize(state)
|
| 110 |
+
|
| 111 |
+
return {
|
| 112 |
+
"error_class": error_class,
|
| 113 |
+
"error_class_name": ERROR_CLASS_NAMES[error_class],
|
| 114 |
+
"state": state,
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def select_action() -> dict:
|
| 119 |
+
"""
|
| 120 |
+
Ask the bandit to select a repair action based on current state.
|
| 121 |
+
Returns dict with keys: action, action_name, scores.
|
| 122 |
+
"""
|
| 123 |
+
if _current_episode is None or _current_episode.current_features is None:
|
| 124 |
+
raise RuntimeError("Call observe_error() before select_action()")
|
| 125 |
+
|
| 126 |
+
b = _get_bandit()
|
| 127 |
+
action, scores = b.select_action(_current_episode.current_features)
|
| 128 |
+
_current_episode.last_action = action
|
| 129 |
+
|
| 130 |
+
return {
|
| 131 |
+
"action": action,
|
| 132 |
+
"action_name": REPAIR_ACTION_NAMES[action],
|
| 133 |
+
"scores": scores,
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_repair_prompt(
|
| 138 |
+
action: RepairAction,
|
| 139 |
+
schema: str,
|
| 140 |
+
question: str,
|
| 141 |
+
failing_sql: str,
|
| 142 |
+
error_message: str,
|
| 143 |
+
) -> dict:
|
| 144 |
+
"""
|
| 145 |
+
Build the system suffix and user message for the chosen repair action.
|
| 146 |
+
Returns dict with keys: system_suffix, user_message.
|
| 147 |
+
"""
|
| 148 |
+
offending_token = extract_offending_token(error_message)
|
| 149 |
+
ctx = RepairContext(
|
| 150 |
+
schema=schema,
|
| 151 |
+
question=question,
|
| 152 |
+
failing_sql=failing_sql,
|
| 153 |
+
error_message=error_message,
|
| 154 |
+
offending_token=offending_token,
|
| 155 |
+
)
|
| 156 |
+
return {
|
| 157 |
+
"system_suffix": get_repair_system_suffix(action),
|
| 158 |
+
"user_message": build_repair_user_message(action, ctx),
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def record_step(
|
| 163 |
+
action: RepairAction,
|
| 164 |
+
success: bool,
|
| 165 |
+
error_message: str,
|
| 166 |
+
sql: str,
|
| 167 |
+
) -> dict:
|
| 168 |
+
"""
|
| 169 |
+
Record the outcome of a repair step and compute shaped reward.
|
| 170 |
+
Returns dict with keys: reward, breakdown.
|
| 171 |
+
"""
|
| 172 |
+
if _current_episode is None or _current_episode.current_state is None:
|
| 173 |
+
raise RuntimeError("Call observe_error() before record_step()")
|
| 174 |
+
|
| 175 |
+
state = _current_episode.current_state
|
| 176 |
+
|
| 177 |
+
grader_input = GraderInput(
|
| 178 |
+
success=success,
|
| 179 |
+
attempt_number=state.attempt_number,
|
| 180 |
+
current_error_class=None if success else classify_error(error_message),
|
| 181 |
+
previous_error_class=_current_episode.previous_error_class,
|
| 182 |
+
)
|
| 183 |
+
result = compute_reward(grader_input)
|
| 184 |
+
|
| 185 |
+
step = EpisodeStep(
|
| 186 |
+
state=state,
|
| 187 |
+
featurized=_current_episode.current_features or featurize(state),
|
| 188 |
+
action=action,
|
| 189 |
+
reward=result.reward,
|
| 190 |
+
error_message=error_message,
|
| 191 |
+
sql=sql,
|
| 192 |
+
success=success,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
_current_episode.steps.append(step)
|
| 196 |
+
_current_episode.previous_error_class = state.error_class
|
| 197 |
+
|
| 198 |
+
return {
|
| 199 |
+
"reward": result.reward,
|
| 200 |
+
"breakdown": {
|
| 201 |
+
"base": result.breakdown.base,
|
| 202 |
+
"attempt_penalty": result.breakdown.attempt_penalty,
|
| 203 |
+
"severity_bonus": result.breakdown.severity_bonus,
|
| 204 |
+
"change_bonus": result.breakdown.change_bonus,
|
| 205 |
+
},
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def end_episode(success: bool) -> Optional[dict]:
|
| 210 |
+
"""
|
| 211 |
+
End the current episode. Runs HER relabeling and updates the bandit.
|
| 212 |
+
Returns dict with keys: total_reward, episode_length.
|
| 213 |
+
"""
|
| 214 |
+
global _current_episode
|
| 215 |
+
if _current_episode is None or not _current_episode.steps:
|
| 216 |
+
_current_episode = None
|
| 217 |
+
return None
|
| 218 |
+
|
| 219 |
+
b = _get_bandit()
|
| 220 |
+
episode, relabeled = record_episode(
|
| 221 |
+
_current_episode.question,
|
| 222 |
+
_current_episode.steps,
|
| 223 |
+
success,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
for exp in relabeled:
|
| 227 |
+
b.update(exp.state, exp.action, exp.reward)
|
| 228 |
+
|
| 229 |
+
b.decay_alpha()
|
| 230 |
+
|
| 231 |
+
result = {
|
| 232 |
+
"total_reward": episode.total_reward,
|
| 233 |
+
"episode_length": len(episode.steps),
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
_current_episode = None
|
| 237 |
+
return result
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# βββ Query Interface ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 241 |
+
|
| 242 |
+
def get_rl_metrics() -> RLMetrics:
|
| 243 |
+
return get_metrics()
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def get_bandit_state() -> dict:
|
| 247 |
+
b = _get_bandit()
|
| 248 |
+
return {
|
| 249 |
+
"action_counts": b.get_action_counts(),
|
| 250 |
+
"total_updates": b.get_total_updates(),
|
| 251 |
+
"alpha": b.get_alpha(),
|
| 252 |
+
"action_distribution": b.get_action_distribution(),
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def is_episode_active() -> bool:
|
| 257 |
+
return _current_episode is not None
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def reset_rl() -> None:
|
| 261 |
+
"""Reset the entire RL system β bandit weights and experience store."""
|
| 262 |
+
global _bandit, _current_episode
|
| 263 |
+
if _bandit:
|
| 264 |
+
_bandit.reset()
|
| 265 |
+
reset_experience()
|
| 266 |
+
_current_episode = None
|
backend/rl/error_classifier.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQL error classifier: maps raw SQLite error messages to one of 8
|
| 3 |
+
canonical ErrorClass values.
|
| 4 |
+
|
| 5 |
+
Severity ordering (lower = less severe / closer to correct):
|
| 6 |
+
OTHER=5, SYNTAX_ERROR=4, NO_SUCH_FUNCTION=3, NO_SUCH_TABLE=3,
|
| 7 |
+
DATATYPE_MISMATCH=2, AGGREGATION_ERROR=2,
|
| 8 |
+
NO_SUCH_COLUMN=1, AMBIGUOUS_COLUMN=1
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import re
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
from rl.types import ErrorClass
|
| 15 |
+
|
| 16 |
+
_SEVERITY: dict[ErrorClass, int] = {
|
| 17 |
+
ErrorClass.OTHER: 5,
|
| 18 |
+
ErrorClass.SYNTAX_ERROR: 4,
|
| 19 |
+
ErrorClass.NO_SUCH_FUNCTION: 3,
|
| 20 |
+
ErrorClass.NO_SUCH_TABLE: 3,
|
| 21 |
+
ErrorClass.DATATYPE_MISMATCH: 2,
|
| 22 |
+
ErrorClass.AGGREGATION_ERROR: 2,
|
| 23 |
+
ErrorClass.NO_SUCH_COLUMN: 1,
|
| 24 |
+
ErrorClass.AMBIGUOUS_COLUMN: 1,
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def error_severity(error_class: ErrorClass) -> int:
|
| 29 |
+
return _SEVERITY[error_class]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def classify_error(error_message: str) -> ErrorClass:
|
| 33 |
+
"""
|
| 34 |
+
Classify a raw SQLite error message into one of 8 canonical classes.
|
| 35 |
+
Patterns are ordered most-specific-first to avoid false matches.
|
| 36 |
+
"""
|
| 37 |
+
msg = error_message.lower()
|
| 38 |
+
|
| 39 |
+
# Column-level errors
|
| 40 |
+
if "no such column" in msg:
|
| 41 |
+
return ErrorClass.NO_SUCH_COLUMN
|
| 42 |
+
if "ambiguous column" in msg:
|
| 43 |
+
return ErrorClass.AMBIGUOUS_COLUMN
|
| 44 |
+
|
| 45 |
+
# Table-level errors
|
| 46 |
+
if "no such table" in msg:
|
| 47 |
+
return ErrorClass.NO_SUCH_TABLE
|
| 48 |
+
|
| 49 |
+
# Function errors
|
| 50 |
+
if "no such function" in msg:
|
| 51 |
+
return ErrorClass.NO_SUCH_FUNCTION
|
| 52 |
+
|
| 53 |
+
# Aggregation / GROUP BY
|
| 54 |
+
if (
|
| 55 |
+
"not an aggregate" in msg
|
| 56 |
+
or "misuse of aggregate" in msg
|
| 57 |
+
or ("group by" in msg and "must appear" in msg)
|
| 58 |
+
or "must be an aggregate" in msg
|
| 59 |
+
):
|
| 60 |
+
return ErrorClass.AGGREGATION_ERROR
|
| 61 |
+
|
| 62 |
+
# Syntax errors (broad β must come after more specific patterns)
|
| 63 |
+
if "syntax error" in msg or re.search(r'near\s+"', msg):
|
| 64 |
+
return ErrorClass.SYNTAX_ERROR
|
| 65 |
+
|
| 66 |
+
# Type errors
|
| 67 |
+
if "datatype mismatch" in msg or "type mismatch" in msg:
|
| 68 |
+
return ErrorClass.DATATYPE_MISMATCH
|
| 69 |
+
|
| 70 |
+
return ErrorClass.OTHER
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def extract_offending_token(error_message: str) -> Optional[str]:
|
| 74 |
+
"""
|
| 75 |
+
Extract the offending token from a SQLite error message.
|
| 76 |
+
Returns None if no specific token can be identified.
|
| 77 |
+
"""
|
| 78 |
+
# "no such column: X"
|
| 79 |
+
m = re.search(r"no such column:\s*(\S+)", error_message, re.IGNORECASE)
|
| 80 |
+
if m:
|
| 81 |
+
return m.group(1)
|
| 82 |
+
|
| 83 |
+
# "no such table: X"
|
| 84 |
+
m = re.search(r"no such table:\s*(\S+)", error_message, re.IGNORECASE)
|
| 85 |
+
if m:
|
| 86 |
+
return m.group(1)
|
| 87 |
+
|
| 88 |
+
# 'near "X": syntax error'
|
| 89 |
+
m = re.search(r'near\s+"([^"]+)"', error_message, re.IGNORECASE)
|
| 90 |
+
if m:
|
| 91 |
+
return m.group(1)
|
| 92 |
+
|
| 93 |
+
# "no such function: X"
|
| 94 |
+
m = re.search(r"no such function:\s*(\S+)", error_message, re.IGNORECASE)
|
| 95 |
+
if m:
|
| 96 |
+
return m.group(1)
|
| 97 |
+
|
| 98 |
+
return None
|
backend/rl/experience.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experience store: logs episodes, persists to disk, and implements
|
| 3 |
+
Hindsight Experience Replay (HER) for reward relabeling.
|
| 4 |
+
|
| 5 |
+
HER (Andrychowicz et al., 2017): If a later attempt in the same episode
|
| 6 |
+
succeeded, relabel earlier failed steps with partial credit proportional
|
| 7 |
+
to their distance from the success step. This multiplies the effective
|
| 8 |
+
training signal from sparse rewards.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
import random
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
from rl.types import (
|
| 21 |
+
Episode,
|
| 22 |
+
EpisodeStep,
|
| 23 |
+
Experience,
|
| 24 |
+
RLMetrics,
|
| 25 |
+
RepairAction,
|
| 26 |
+
REPAIR_ACTION_NAMES,
|
| 27 |
+
ERROR_CLASS_NAMES,
|
| 28 |
+
)
|
| 29 |
+
from rl.grader import compute_episode_reward
|
| 30 |
+
|
| 31 |
+
_DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
|
| 32 |
+
EXPERIENCE_PATH = _DATA_DIR / "rl_experiences.json"
|
| 33 |
+
MAX_EPISODES = 500
|
| 34 |
+
|
| 35 |
+
_episodes: list[Episode] = []
|
| 36 |
+
_loaded: bool = False
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _ensure_loaded() -> None:
|
| 40 |
+
global _loaded, _episodes
|
| 41 |
+
if _loaded:
|
| 42 |
+
return
|
| 43 |
+
_loaded = True
|
| 44 |
+
try:
|
| 45 |
+
if EXPERIENCE_PATH.exists():
|
| 46 |
+
raw = json.loads(EXPERIENCE_PATH.read_text())
|
| 47 |
+
_episodes = [Episode(**ep) for ep in raw]
|
| 48 |
+
except Exception:
|
| 49 |
+
_episodes = []
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _persist() -> None:
|
| 53 |
+
try:
|
| 54 |
+
EXPERIENCE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 55 |
+
data = [ep.model_dump() for ep in _episodes[-MAX_EPISODES:]]
|
| 56 |
+
EXPERIENCE_PATH.write_text(json.dumps(data, default=str))
|
| 57 |
+
except Exception:
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def record_episode(
|
| 62 |
+
question: str,
|
| 63 |
+
steps: list[EpisodeStep],
|
| 64 |
+
success: bool,
|
| 65 |
+
) -> tuple[Episode, list[Experience]]:
|
| 66 |
+
"""
|
| 67 |
+
Record a completed episode, run HER relabeling, and persist.
|
| 68 |
+
Returns (episode, relabeled_experiences).
|
| 69 |
+
"""
|
| 70 |
+
_ensure_loaded()
|
| 71 |
+
|
| 72 |
+
step_rewards = [s.reward for s in steps]
|
| 73 |
+
total_reward = compute_episode_reward(step_rewards, success)
|
| 74 |
+
|
| 75 |
+
episode = Episode(
|
| 76 |
+
id=f"ep-{int(time.time() * 1000)}-{random.randint(1000, 9999)}",
|
| 77 |
+
question=question,
|
| 78 |
+
steps=steps,
|
| 79 |
+
total_reward=total_reward,
|
| 80 |
+
success=success,
|
| 81 |
+
timestamp=time.time(),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
_episodes.append(episode)
|
| 85 |
+
if len(_episodes) > MAX_EPISODES:
|
| 86 |
+
_episodes[:] = _episodes[-MAX_EPISODES:]
|
| 87 |
+
_persist()
|
| 88 |
+
|
| 89 |
+
relabeled = _apply_her(episode)
|
| 90 |
+
return episode, relabeled
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _apply_her(episode: Episode) -> list[Experience]:
|
| 94 |
+
"""
|
| 95 |
+
Hindsight Experience Replay.
|
| 96 |
+
|
| 97 |
+
If the episode eventually succeeded at step T, relabel earlier
|
| 98 |
+
failed steps with a hindsight bonus:
|
| 99 |
+
bonus(t) = 0.3 * (1 - (T - t) / T)
|
| 100 |
+
|
| 101 |
+
Steps closer to the eventual success receive more credit.
|
| 102 |
+
"""
|
| 103 |
+
experiences: list[Experience] = []
|
| 104 |
+
success_step_idx = next(
|
| 105 |
+
(i for i, s in enumerate(episode.steps) if s.success), -1
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
for t, step in enumerate(episode.steps):
|
| 109 |
+
reward = step.reward
|
| 110 |
+
|
| 111 |
+
if success_step_idx > t:
|
| 112 |
+
distance = success_step_idx - t
|
| 113 |
+
total_steps = len(episode.steps)
|
| 114 |
+
her_bonus = 0.3 * (1.0 - distance / total_steps)
|
| 115 |
+
reward += her_bonus
|
| 116 |
+
|
| 117 |
+
next_step = episode.steps[t + 1] if t < len(episode.steps) - 1 else None
|
| 118 |
+
|
| 119 |
+
experiences.append(
|
| 120 |
+
Experience(
|
| 121 |
+
state=step.featurized,
|
| 122 |
+
action=step.action,
|
| 123 |
+
reward=reward,
|
| 124 |
+
next_state=next_step.featurized if next_step else None,
|
| 125 |
+
done=(t == len(episode.steps) - 1),
|
| 126 |
+
timestamp=episode.timestamp,
|
| 127 |
+
metadata={
|
| 128 |
+
"question": episode.question,
|
| 129 |
+
"error_message": step.error_message,
|
| 130 |
+
"sql": step.sql,
|
| 131 |
+
"error_class": int(step.state.error_class),
|
| 132 |
+
"attempt_number": step.state.attempt_number,
|
| 133 |
+
},
|
| 134 |
+
)
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
return experiences
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def replay_all(bandit) -> int:
|
| 141 |
+
"""
|
| 142 |
+
Replay all stored experiences through the bandit to rebuild weights.
|
| 143 |
+
Useful after a reset or if weights are lost.
|
| 144 |
+
"""
|
| 145 |
+
_ensure_loaded()
|
| 146 |
+
count = 0
|
| 147 |
+
for ep in _episodes:
|
| 148 |
+
relabeled = _apply_her(ep)
|
| 149 |
+
for exp in relabeled:
|
| 150 |
+
bandit.update(exp.state, exp.action, exp.reward)
|
| 151 |
+
count += 1
|
| 152 |
+
return count
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_metrics() -> RLMetrics:
|
| 156 |
+
_ensure_loaded()
|
| 157 |
+
|
| 158 |
+
recent_window = 50
|
| 159 |
+
recent = _episodes[-recent_window:]
|
| 160 |
+
all_steps = [s for ep in _episodes for s in ep.steps]
|
| 161 |
+
|
| 162 |
+
action_dist: dict[str, int] = {}
|
| 163 |
+
error_dist: dict[str, int] = {}
|
| 164 |
+
|
| 165 |
+
for step in all_steps:
|
| 166 |
+
a_name = REPAIR_ACTION_NAMES[step.action]
|
| 167 |
+
action_dist[a_name] = action_dist.get(a_name, 0) + 1
|
| 168 |
+
e_name = ERROR_CLASS_NAMES[step.state.error_class]
|
| 169 |
+
error_dist[e_name] = error_dist.get(e_name, 0) + 1
|
| 170 |
+
|
| 171 |
+
return RLMetrics(
|
| 172 |
+
total_episodes=len(_episodes),
|
| 173 |
+
total_steps=len(all_steps),
|
| 174 |
+
cumulative_reward=sum(ep.total_reward for ep in _episodes),
|
| 175 |
+
success_rate=(
|
| 176 |
+
sum(1 for ep in recent if ep.success) / len(recent)
|
| 177 |
+
if recent
|
| 178 |
+
else 0.0
|
| 179 |
+
),
|
| 180 |
+
avg_attempts=(
|
| 181 |
+
sum(len(ep.steps) for ep in recent) / len(recent)
|
| 182 |
+
if recent
|
| 183 |
+
else 0.0
|
| 184 |
+
),
|
| 185 |
+
action_distribution=action_dist,
|
| 186 |
+
error_distribution=error_dist,
|
| 187 |
+
reward_history=[ep.total_reward for ep in _episodes],
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_episodes() -> list[Episode]:
|
| 192 |
+
_ensure_loaded()
|
| 193 |
+
return list(_episodes)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def get_recent_episodes(n: int) -> list[Episode]:
|
| 197 |
+
_ensure_loaded()
|
| 198 |
+
return _episodes[-n:]
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def reset_experience() -> None:
|
| 202 |
+
global _episodes, _loaded
|
| 203 |
+
_episodes = []
|
| 204 |
+
_loaded = True
|
| 205 |
+
try:
|
| 206 |
+
EXPERIENCE_PATH.unlink(missing_ok=True)
|
| 207 |
+
except Exception:
|
| 208 |
+
pass
|
backend/rl/grader.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shaped reward function for the SQL debug RL environment.
|
| 3 |
+
|
| 4 |
+
Reward components:
|
| 5 |
+
+1.0 base success reward
|
| 6 |
+
-0.1 per attempt (attempt penalty β incentivizes early resolution)
|
| 7 |
+
+0.2 if error severity decreased (progress signal)
|
| 8 |
+
+0.1 if error class changed at all (exploration signal)
|
| 9 |
+
-0.1 base failure penalty per step
|
| 10 |
+
|
| 11 |
+
The shaping is potential-based (Ng et al., 1999), preserving
|
| 12 |
+
the optimal policy while accelerating learning.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from typing import Optional
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
|
| 20 |
+
from rl.types import ErrorClass
|
| 21 |
+
from rl.error_classifier import error_severity
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class GraderInput:
|
| 26 |
+
success: bool
|
| 27 |
+
attempt_number: int # 1-indexed
|
| 28 |
+
current_error_class: Optional[ErrorClass] # None if success
|
| 29 |
+
previous_error_class: Optional[ErrorClass] # None on first attempt
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class RewardBreakdown:
|
| 34 |
+
base: float
|
| 35 |
+
attempt_penalty: float
|
| 36 |
+
severity_bonus: float
|
| 37 |
+
change_bonus: float
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class GraderOutput:
|
| 42 |
+
reward: float
|
| 43 |
+
breakdown: RewardBreakdown
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def compute_reward(inp: GraderInput) -> GraderOutput:
|
| 47 |
+
if inp.success:
|
| 48 |
+
base = 1.0
|
| 49 |
+
attempt_penalty = -0.1 * (inp.attempt_number - 1)
|
| 50 |
+
return GraderOutput(
|
| 51 |
+
reward=base + attempt_penalty,
|
| 52 |
+
breakdown=RewardBreakdown(
|
| 53 |
+
base=base,
|
| 54 |
+
attempt_penalty=attempt_penalty,
|
| 55 |
+
severity_bonus=0.0,
|
| 56 |
+
change_bonus=0.0,
|
| 57 |
+
),
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Failed step β base penalty + potential shaping
|
| 61 |
+
base = -0.1
|
| 62 |
+
attempt_penalty = -0.05 * inp.attempt_number
|
| 63 |
+
|
| 64 |
+
severity_bonus = 0.0
|
| 65 |
+
change_bonus = 0.0
|
| 66 |
+
|
| 67 |
+
if inp.previous_error_class is not None and inp.current_error_class is not None:
|
| 68 |
+
prev_sev = error_severity(inp.previous_error_class)
|
| 69 |
+
curr_sev = error_severity(inp.current_error_class)
|
| 70 |
+
|
| 71 |
+
if curr_sev < prev_sev:
|
| 72 |
+
severity_bonus = 0.2 # Progress toward solution
|
| 73 |
+
elif curr_sev > prev_sev:
|
| 74 |
+
severity_bonus = -0.1 # Regression
|
| 75 |
+
|
| 76 |
+
if inp.current_error_class != inp.previous_error_class:
|
| 77 |
+
change_bonus = 0.1 # At least something different happened
|
| 78 |
+
|
| 79 |
+
reward = base + attempt_penalty + severity_bonus + change_bonus
|
| 80 |
+
|
| 81 |
+
return GraderOutput(
|
| 82 |
+
reward=reward,
|
| 83 |
+
breakdown=RewardBreakdown(
|
| 84 |
+
base=base,
|
| 85 |
+
attempt_penalty=attempt_penalty,
|
| 86 |
+
severity_bonus=severity_bonus,
|
| 87 |
+
change_bonus=change_bonus,
|
| 88 |
+
),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def compute_episode_reward(step_rewards: list[float], success: bool) -> float:
|
| 93 |
+
"""
|
| 94 |
+
Compute total episode reward from individual step rewards.
|
| 95 |
+
Includes a terminal bonus/penalty based on final outcome.
|
| 96 |
+
"""
|
| 97 |
+
total = sum(step_rewards)
|
| 98 |
+
terminal = 0.5 if success else -0.5
|
| 99 |
+
return total + terminal
|
backend/rl/linucb.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LinUCB Contextual Bandit (Li et al., 2010).
|
| 3 |
+
|
| 4 |
+
Maintains per-action inverse covariance matrices using the
|
| 5 |
+
Sherman-Morrison rank-1 update formula for O(d^2) updates.
|
| 6 |
+
|
| 7 |
+
For each action a in {0..K-1}:
|
| 8 |
+
A_inv[a] β dΓd inverse covariance (starts as I_d)
|
| 9 |
+
b[a] β d reward-weighted feature accumulator
|
| 10 |
+
theta[a] = A_inv[a] @ b[a] (ridge regression estimate)
|
| 11 |
+
UCB_a(x) = theta[a] @ x + alpha * sqrt(max(0, x @ A_inv[a] @ x))
|
| 12 |
+
|
| 13 |
+
Action selection: argmax_a UCB_a(x)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import random
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import List, Optional, Tuple
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
from rl.types import FEATURE_DIM, NUM_ACTIONS, RepairAction, REPAIR_ACTION_NAMES
|
| 27 |
+
|
| 28 |
+
# Default path β can be overridden by DATA_DIR env var
|
| 29 |
+
_DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
|
| 30 |
+
WEIGHTS_PATH = _DATA_DIR / "rl_weights.json"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class LinUCB:
|
| 34 |
+
"""
|
| 35 |
+
LinUCB contextual bandit with Sherman-Morrison updates and alpha decay.
|
| 36 |
+
Weights are persisted to JSON after every 10 updates.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
d: int = FEATURE_DIM,
|
| 42 |
+
K: int = NUM_ACTIONS,
|
| 43 |
+
alpha: float = 1.5,
|
| 44 |
+
) -> None:
|
| 45 |
+
self.d = d
|
| 46 |
+
self.K = K
|
| 47 |
+
self.alpha = alpha
|
| 48 |
+
self.total_updates = 0
|
| 49 |
+
|
| 50 |
+
loaded = self._load_weights()
|
| 51 |
+
if loaded is not None:
|
| 52 |
+
self.A_inv = loaded["A_inv"]
|
| 53 |
+
self.b = loaded["b"]
|
| 54 |
+
self.counts = loaded["counts"]
|
| 55 |
+
self.total_updates = loaded["total_updates"]
|
| 56 |
+
else:
|
| 57 |
+
self.A_inv: List[np.ndarray] = [np.eye(d) for _ in range(K)]
|
| 58 |
+
self.b: List[np.ndarray] = [np.zeros(d) for _ in range(K)]
|
| 59 |
+
self.counts: List[int] = [0] * K
|
| 60 |
+
|
| 61 |
+
# βββ Core Interface ββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
|
| 63 |
+
def select_action(self, x: List[float]) -> Tuple[RepairAction, List[float]]:
|
| 64 |
+
"""
|
| 65 |
+
Select the action with highest UCB score.
|
| 66 |
+
Returns (action, scores_for_all_actions).
|
| 67 |
+
"""
|
| 68 |
+
xv = np.array(x, dtype=np.float64)
|
| 69 |
+
scores = []
|
| 70 |
+
|
| 71 |
+
for a in range(self.K):
|
| 72 |
+
theta = self.A_inv[a] @ self.b[a]
|
| 73 |
+
exploit = float(theta @ xv)
|
| 74 |
+
quad = float(xv @ self.A_inv[a] @ xv)
|
| 75 |
+
explore = self.alpha * float(np.sqrt(max(0.0, quad)))
|
| 76 |
+
scores.append(exploit + explore)
|
| 77 |
+
|
| 78 |
+
# Argmax with random tie-breaking
|
| 79 |
+
best_action = 0
|
| 80 |
+
best_score = scores[0]
|
| 81 |
+
for a in range(1, self.K):
|
| 82 |
+
if scores[a] > best_score or (
|
| 83 |
+
scores[a] == best_score and random.random() > 0.5
|
| 84 |
+
):
|
| 85 |
+
best_score = scores[a]
|
| 86 |
+
best_action = a
|
| 87 |
+
|
| 88 |
+
return RepairAction(best_action), scores
|
| 89 |
+
|
| 90 |
+
def update(self, x: List[float], action: RepairAction, reward: float) -> None:
|
| 91 |
+
"""
|
| 92 |
+
Update the model after observing a reward.
|
| 93 |
+
Uses Sherman-Morrison: (A + xx^T)^{-1} = A^{-1} - (A^{-1}xx^T A^{-1}) / (1 + x^T A^{-1} x)
|
| 94 |
+
"""
|
| 95 |
+
a = int(action)
|
| 96 |
+
xv = np.array(x, dtype=np.float64)
|
| 97 |
+
|
| 98 |
+
A_inv_x = self.A_inv[a] @ xv # shape (d,)
|
| 99 |
+
denom = 1.0 + float(xv @ A_inv_x) # scalar
|
| 100 |
+
|
| 101 |
+
# Rank-1 downdate
|
| 102 |
+
self.A_inv[a] -= np.outer(A_inv_x, A_inv_x) / denom
|
| 103 |
+
|
| 104 |
+
# Reward-weighted feature accumulation
|
| 105 |
+
self.b[a] += reward * xv
|
| 106 |
+
|
| 107 |
+
self.counts[a] += 1
|
| 108 |
+
self.total_updates += 1
|
| 109 |
+
|
| 110 |
+
if self.total_updates % 10 == 0:
|
| 111 |
+
self.save_weights()
|
| 112 |
+
|
| 113 |
+
def get_estimated_rewards(self, x: List[float]) -> List[float]:
|
| 114 |
+
"""
|
| 115 |
+
Return theta^T x for each action (no exploration bonus).
|
| 116 |
+
Useful for understanding learned policy.
|
| 117 |
+
"""
|
| 118 |
+
xv = np.array(x, dtype=np.float64)
|
| 119 |
+
return [float((self.A_inv[a] @ self.b[a]) @ xv) for a in range(self.K)]
|
| 120 |
+
|
| 121 |
+
def get_action_counts(self) -> List[int]:
|
| 122 |
+
return list(self.counts)
|
| 123 |
+
|
| 124 |
+
def get_total_updates(self) -> int:
|
| 125 |
+
return self.total_updates
|
| 126 |
+
|
| 127 |
+
def get_alpha(self) -> float:
|
| 128 |
+
return self.alpha
|
| 129 |
+
|
| 130 |
+
def decay_alpha(self, min_alpha: float = 0.3) -> None:
|
| 131 |
+
"""Decay exploration coefficient toward exploitation."""
|
| 132 |
+
self.alpha = max(min_alpha, self.alpha * 0.995)
|
| 133 |
+
|
| 134 |
+
def get_action_distribution(self) -> dict:
|
| 135 |
+
total = sum(self.counts) or 1
|
| 136 |
+
return {
|
| 137 |
+
REPAIR_ACTION_NAMES[RepairAction(a)]: self.counts[a] / total
|
| 138 |
+
for a in range(self.K)
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
# βββ Persistence βββββββββββββββββββββββββββββββββββββββββββββ
|
| 142 |
+
|
| 143 |
+
def save_weights(self) -> None:
|
| 144 |
+
try:
|
| 145 |
+
WEIGHTS_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 146 |
+
data = {
|
| 147 |
+
"A_inv": [m.tolist() for m in self.A_inv],
|
| 148 |
+
"b": [v.tolist() for v in self.b],
|
| 149 |
+
"counts": self.counts,
|
| 150 |
+
"total_updates": self.total_updates,
|
| 151 |
+
"alpha": self.alpha,
|
| 152 |
+
}
|
| 153 |
+
WEIGHTS_PATH.write_text(json.dumps(data))
|
| 154 |
+
except Exception:
|
| 155 |
+
pass # Non-fatal
|
| 156 |
+
|
| 157 |
+
def _load_weights(self) -> Optional[dict]:
|
| 158 |
+
try:
|
| 159 |
+
if not WEIGHTS_PATH.exists():
|
| 160 |
+
return None
|
| 161 |
+
raw = json.loads(WEIGHTS_PATH.read_text())
|
| 162 |
+
A_inv = [np.array(m, dtype=np.float64) for m in raw["A_inv"]]
|
| 163 |
+
b = [np.array(v, dtype=np.float64) for v in raw["b"]]
|
| 164 |
+
# Validate dimensions
|
| 165 |
+
if (
|
| 166 |
+
len(A_inv) == self.K
|
| 167 |
+
and A_inv[0].shape == (self.d, self.d)
|
| 168 |
+
and len(b) == self.K
|
| 169 |
+
and b[0].shape == (self.d,)
|
| 170 |
+
):
|
| 171 |
+
return {
|
| 172 |
+
"A_inv": A_inv,
|
| 173 |
+
"b": b,
|
| 174 |
+
"counts": raw["counts"],
|
| 175 |
+
"total_updates": raw["total_updates"],
|
| 176 |
+
}
|
| 177 |
+
return None
|
| 178 |
+
except Exception:
|
| 179 |
+
return None
|
| 180 |
+
|
| 181 |
+
def reset(self) -> None:
|
| 182 |
+
self.A_inv = [np.eye(self.d) for _ in range(self.K)]
|
| 183 |
+
self.b = [np.zeros(self.d) for _ in range(self.K)]
|
| 184 |
+
self.counts = [0] * self.K
|
| 185 |
+
self.total_updates = 0
|
| 186 |
+
self.alpha = 1.5
|
| 187 |
+
try:
|
| 188 |
+
WEIGHTS_PATH.unlink(missing_ok=True)
|
| 189 |
+
except Exception:
|
| 190 |
+
pass
|
backend/rl/repair_strategies.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Repair strategy prompt templates for each of the 8 RepairAction values.
|
| 3 |
+
|
| 4 |
+
Each strategy provides:
|
| 5 |
+
- system_suffix: appended to the base system prompt
|
| 6 |
+
- user_template: callable that builds the user message given a RepairContext
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Optional, Callable
|
| 13 |
+
|
| 14 |
+
from rl.types import RepairAction
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class RepairContext:
|
| 19 |
+
schema: str
|
| 20 |
+
question: str
|
| 21 |
+
failing_sql: str
|
| 22 |
+
error_message: str
|
| 23 |
+
offending_token: Optional[str]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class RepairStrategy:
|
| 28 |
+
action: RepairAction
|
| 29 |
+
name: str
|
| 30 |
+
system_suffix: str
|
| 31 |
+
user_template: Callable[[RepairContext], str]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _tmpl_rewrite_full(ctx: RepairContext) -> str:
|
| 35 |
+
return (
|
| 36 |
+
f"Schema:\n{ctx.schema}\n\n"
|
| 37 |
+
f"Question: {ctx.question}\n\n"
|
| 38 |
+
f"A previous attempt failed with: {ctx.error_message}\n\n"
|
| 39 |
+
"Write a completely new SQL query from scratch. Do NOT reference the previous attempt."
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _tmpl_fix_column(ctx: RepairContext) -> str:
|
| 44 |
+
token_hint = f"\n\nThe problematic column is: {ctx.offending_token}" if ctx.offending_token else ""
|
| 45 |
+
return (
|
| 46 |
+
f"Schema:\n{ctx.schema}\n\n"
|
| 47 |
+
f"Question: {ctx.question}\n\n"
|
| 48 |
+
f"Previous SQL:\n{ctx.failing_sql}\n\n"
|
| 49 |
+
f"Error: {ctx.error_message}"
|
| 50 |
+
f"{token_hint}\n\n"
|
| 51 |
+
"Fix ONLY the column name issue. Check the schema for correct column names."
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _tmpl_fix_table(ctx: RepairContext) -> str:
|
| 56 |
+
token_hint = f"\n\nThe problematic table is: {ctx.offending_token}" if ctx.offending_token else ""
|
| 57 |
+
return (
|
| 58 |
+
f"Schema:\n{ctx.schema}\n\n"
|
| 59 |
+
f"Question: {ctx.question}\n\n"
|
| 60 |
+
f"Previous SQL:\n{ctx.failing_sql}\n\n"
|
| 61 |
+
f"Error: {ctx.error_message}"
|
| 62 |
+
f"{token_hint}\n\n"
|
| 63 |
+
"Fix the table name or JOIN issue. Verify all table names exist in the schema."
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _tmpl_add_groupby(ctx: RepairContext) -> str:
|
| 68 |
+
return (
|
| 69 |
+
f"Schema:\n{ctx.schema}\n\n"
|
| 70 |
+
f"Question: {ctx.question}\n\n"
|
| 71 |
+
f"Previous SQL:\n{ctx.failing_sql}\n\n"
|
| 72 |
+
f"Error: {ctx.error_message}\n\n"
|
| 73 |
+
"Fix the GROUP BY / aggregation issue. Ensure every non-aggregate column in SELECT is in GROUP BY."
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _tmpl_rewrite_cte(ctx: RepairContext) -> str:
|
| 78 |
+
return (
|
| 79 |
+
f"Schema:\n{ctx.schema}\n\n"
|
| 80 |
+
f"Question: {ctx.question}\n\n"
|
| 81 |
+
f"Previous SQL:\n{ctx.failing_sql}\n\n"
|
| 82 |
+
f"Error: {ctx.error_message}\n\n"
|
| 83 |
+
"Restructure the CTEs or subqueries. Break the query into clear, named WITH clauses."
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _tmpl_fix_syntax(ctx: RepairContext) -> str:
|
| 88 |
+
token_hint = f"\n\nSyntax error near: {ctx.offending_token}" if ctx.offending_token else ""
|
| 89 |
+
return (
|
| 90 |
+
f"Schema:\n{ctx.schema}\n\n"
|
| 91 |
+
f"Question: {ctx.question}\n\n"
|
| 92 |
+
f"Previous SQL:\n{ctx.failing_sql}\n\n"
|
| 93 |
+
f"Error: {ctx.error_message}"
|
| 94 |
+
f"{token_hint}\n\n"
|
| 95 |
+
"Fix the syntax error. Check for typos, missing commas, unmatched parentheses."
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _tmpl_change_dialect(ctx: RepairContext) -> str:
|
| 100 |
+
return (
|
| 101 |
+
f"Schema:\n{ctx.schema}\n\n"
|
| 102 |
+
f"Question: {ctx.question}\n\n"
|
| 103 |
+
f"Previous SQL:\n{ctx.failing_sql}\n\n"
|
| 104 |
+
f"Error: {ctx.error_message}\n\n"
|
| 105 |
+
"The SQL uses functions or syntax not supported by SQLite. "
|
| 106 |
+
"Rewrite using SQLite-compatible alternatives."
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _tmpl_relax_filter(ctx: RepairContext) -> str:
|
| 111 |
+
return (
|
| 112 |
+
f"Schema:\n{ctx.schema}\n\n"
|
| 113 |
+
f"Question: {ctx.question}\n\n"
|
| 114 |
+
f"Previous SQL:\n{ctx.failing_sql}\n\n"
|
| 115 |
+
f"Error: {ctx.error_message}\n\n"
|
| 116 |
+
"Review and relax the WHERE/HAVING conditions. "
|
| 117 |
+
"Check date formats, value ranges, and filter logic."
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
_STRATEGIES: dict[RepairAction, RepairStrategy] = {
|
| 122 |
+
RepairAction.REWRITE_FULL: RepairStrategy(
|
| 123 |
+
action=RepairAction.REWRITE_FULL,
|
| 124 |
+
name="Full Rewrite",
|
| 125 |
+
system_suffix=(
|
| 126 |
+
"\n\nIMPORTANT: The previous SQL attempt was fundamentally flawed. "
|
| 127 |
+
"Discard it entirely and write a new query from scratch based only on "
|
| 128 |
+
"the schema and question. Do NOT try to patch the previous SQL."
|
| 129 |
+
),
|
| 130 |
+
user_template=_tmpl_rewrite_full,
|
| 131 |
+
),
|
| 132 |
+
RepairAction.FIX_COLUMN: RepairStrategy(
|
| 133 |
+
action=RepairAction.FIX_COLUMN,
|
| 134 |
+
name="Fix Column",
|
| 135 |
+
system_suffix=(
|
| 136 |
+
"\n\nIMPORTANT: The previous SQL referenced a wrong column name. "
|
| 137 |
+
"Carefully check the schema for the exact column names in each table. "
|
| 138 |
+
"Pay attention to singular vs plural, underscores, and exact spelling."
|
| 139 |
+
),
|
| 140 |
+
user_template=_tmpl_fix_column,
|
| 141 |
+
),
|
| 142 |
+
RepairAction.FIX_TABLE: RepairStrategy(
|
| 143 |
+
action=RepairAction.FIX_TABLE,
|
| 144 |
+
name="Fix Table",
|
| 145 |
+
system_suffix=(
|
| 146 |
+
"\n\nIMPORTANT: The previous SQL referenced a wrong table name or had "
|
| 147 |
+
"incorrect JOIN relationships. Check the schema for exact table names "
|
| 148 |
+
"and foreign key relationships."
|
| 149 |
+
),
|
| 150 |
+
user_template=_tmpl_fix_table,
|
| 151 |
+
),
|
| 152 |
+
RepairAction.ADD_GROUPBY: RepairStrategy(
|
| 153 |
+
action=RepairAction.ADD_GROUPBY,
|
| 154 |
+
name="Fix GROUP BY",
|
| 155 |
+
system_suffix=(
|
| 156 |
+
"\n\nIMPORTANT: The previous SQL has an aggregation error. Every column "
|
| 157 |
+
"in SELECT that is not inside an aggregate function (COUNT, SUM, AVG, etc.) "
|
| 158 |
+
"MUST appear in the GROUP BY clause. Check all selected columns."
|
| 159 |
+
),
|
| 160 |
+
user_template=_tmpl_add_groupby,
|
| 161 |
+
),
|
| 162 |
+
RepairAction.REWRITE_CTE: RepairStrategy(
|
| 163 |
+
action=RepairAction.REWRITE_CTE,
|
| 164 |
+
name="Rewrite CTE/Subquery",
|
| 165 |
+
system_suffix=(
|
| 166 |
+
"\n\nIMPORTANT: The previous SQL had issues with CTEs or subqueries. "
|
| 167 |
+
"Restructure the query β consider using WITH clauses for clarity, or "
|
| 168 |
+
"flatten nested subqueries. Ensure CTE column names are explicitly defined if needed."
|
| 169 |
+
),
|
| 170 |
+
user_template=_tmpl_rewrite_cte,
|
| 171 |
+
),
|
| 172 |
+
RepairAction.FIX_SYNTAX: RepairStrategy(
|
| 173 |
+
action=RepairAction.FIX_SYNTAX,
|
| 174 |
+
name="Fix Syntax",
|
| 175 |
+
system_suffix=(
|
| 176 |
+
"\n\nIMPORTANT: The previous SQL has a syntax error. Check for: "
|
| 177 |
+
"missing commas, unmatched parentheses, misspelled keywords, "
|
| 178 |
+
"incorrect operator usage, missing AS aliases."
|
| 179 |
+
),
|
| 180 |
+
user_template=_tmpl_fix_syntax,
|
| 181 |
+
),
|
| 182 |
+
RepairAction.CHANGE_DIALECT: RepairStrategy(
|
| 183 |
+
action=RepairAction.CHANGE_DIALECT,
|
| 184 |
+
name="Fix Dialect",
|
| 185 |
+
system_suffix=(
|
| 186 |
+
"\n\nIMPORTANT: The previous SQL used functions or syntax not available in SQLite. "
|
| 187 |
+
"Key SQLite rules:\n"
|
| 188 |
+
"- Use strftime() for date formatting, NOT DATE_FORMAT or EXTRACT\n"
|
| 189 |
+
"- No FULL OUTER JOIN or RIGHT JOIN β use LEFT JOIN with UNION\n"
|
| 190 |
+
"- Use CAST(x AS INTEGER), not CONVERT()\n"
|
| 191 |
+
"- No ILIKE β use LIKE (case-insensitive by default for ASCII)\n"
|
| 192 |
+
"- String concatenation uses || not CONCAT()\n"
|
| 193 |
+
"- No LIMIT inside subqueries with IN (use CTE instead)"
|
| 194 |
+
),
|
| 195 |
+
user_template=_tmpl_change_dialect,
|
| 196 |
+
),
|
| 197 |
+
RepairAction.RELAX_FILTER: RepairStrategy(
|
| 198 |
+
action=RepairAction.RELAX_FILTER,
|
| 199 |
+
name="Relax Filter",
|
| 200 |
+
system_suffix=(
|
| 201 |
+
"\n\nIMPORTANT: The previous SQL may have overly restrictive WHERE conditions, "
|
| 202 |
+
"incorrect date ranges, or wrong filter values causing empty results or errors. "
|
| 203 |
+
"Review the filter conditions and broaden them to capture the intended data."
|
| 204 |
+
),
|
| 205 |
+
user_template=_tmpl_relax_filter,
|
| 206 |
+
),
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def get_repair_system_suffix(action: RepairAction) -> str:
|
| 211 |
+
return _STRATEGIES[action].system_suffix
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def build_repair_user_message(action: RepairAction, ctx: RepairContext) -> str:
|
| 215 |
+
return _STRATEGIES[action].user_template(ctx)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def get_repair_name(action: RepairAction) -> str:
|
| 219 |
+
return _STRATEGIES[action].name
|
backend/rl/types.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RL type definitions and feature engineering.
|
| 3 |
+
|
| 4 |
+
Mirrors the TypeScript types.ts exactly:
|
| 5 |
+
- 8 error classes, 8 repair actions
|
| 6 |
+
- FEATURE_DIM = 20
|
| 7 |
+
- featurize() builds the state vector
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from enum import IntEnum
|
| 13 |
+
from typing import Optional, List, Dict, Any
|
| 14 |
+
from pydantic import BaseModel
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# βββ Error Taxonomy βββββββββββββββββββββββββββββββββββββββββββββ
|
| 18 |
+
|
| 19 |
+
class ErrorClass(IntEnum):
|
| 20 |
+
NO_SUCH_COLUMN = 0
|
| 21 |
+
NO_SUCH_TABLE = 1
|
| 22 |
+
SYNTAX_ERROR = 2
|
| 23 |
+
AMBIGUOUS_COLUMN = 3
|
| 24 |
+
DATATYPE_MISMATCH = 4
|
| 25 |
+
NO_SUCH_FUNCTION = 5
|
| 26 |
+
AGGREGATION_ERROR = 6
|
| 27 |
+
OTHER = 7
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
ERROR_CLASS_NAMES: Dict[ErrorClass, str] = {
|
| 31 |
+
ErrorClass.NO_SUCH_COLUMN: "no_such_column",
|
| 32 |
+
ErrorClass.NO_SUCH_TABLE: "no_such_table",
|
| 33 |
+
ErrorClass.SYNTAX_ERROR: "syntax_error",
|
| 34 |
+
ErrorClass.AMBIGUOUS_COLUMN: "ambiguous_column",
|
| 35 |
+
ErrorClass.DATATYPE_MISMATCH: "datatype_mismatch",
|
| 36 |
+
ErrorClass.NO_SUCH_FUNCTION: "no_such_function",
|
| 37 |
+
ErrorClass.AGGREGATION_ERROR: "aggregation_error",
|
| 38 |
+
ErrorClass.OTHER: "other",
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
NUM_ERROR_CLASSES = 8
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# βββ Repair Actions βββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
+
|
| 46 |
+
class RepairAction(IntEnum):
|
| 47 |
+
REWRITE_FULL = 0
|
| 48 |
+
FIX_COLUMN = 1
|
| 49 |
+
FIX_TABLE = 2
|
| 50 |
+
ADD_GROUPBY = 3
|
| 51 |
+
REWRITE_CTE = 4
|
| 52 |
+
FIX_SYNTAX = 5
|
| 53 |
+
CHANGE_DIALECT = 6
|
| 54 |
+
RELAX_FILTER = 7
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
REPAIR_ACTION_NAMES: Dict[RepairAction, str] = {
|
| 58 |
+
RepairAction.REWRITE_FULL: "rewrite_full",
|
| 59 |
+
RepairAction.FIX_COLUMN: "fix_column",
|
| 60 |
+
RepairAction.FIX_TABLE: "fix_table",
|
| 61 |
+
RepairAction.ADD_GROUPBY: "add_groupby",
|
| 62 |
+
RepairAction.REWRITE_CTE: "rewrite_cte",
|
| 63 |
+
RepairAction.FIX_SYNTAX: "fix_syntax",
|
| 64 |
+
RepairAction.CHANGE_DIALECT: "change_dialect",
|
| 65 |
+
RepairAction.RELAX_FILTER: "relax_filter",
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
# Inverse map: name β enum
|
| 69 |
+
REPAIR_ACTION_BY_NAME: Dict[str, RepairAction] = {v: k for k, v in REPAIR_ACTION_NAMES.items()}
|
| 70 |
+
|
| 71 |
+
NUM_ACTIONS = 8
|
| 72 |
+
|
| 73 |
+
# Feature vector:
|
| 74 |
+
# [0..7] error class one-hot (8)
|
| 75 |
+
# [8] attempt / 5.0 (1)
|
| 76 |
+
# [9..16] prev action one-hot (8)
|
| 77 |
+
# [17] error_changed (1)
|
| 78 |
+
# [18] consec_count / 5.0 (1)
|
| 79 |
+
# [19] bias = 1.0 (1)
|
| 80 |
+
# total = 20
|
| 81 |
+
FEATURE_DIM = 20
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# βββ State ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 85 |
+
|
| 86 |
+
class RLState(BaseModel):
|
| 87 |
+
error_class: ErrorClass
|
| 88 |
+
attempt_number: int # 1-indexed
|
| 89 |
+
previous_action: Optional[RepairAction] = None
|
| 90 |
+
error_changed: bool = False
|
| 91 |
+
consecutive_same_error: int = 1
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def featurize(state: RLState) -> List[float]:
|
| 95 |
+
"""Build the 20-dimensional feature vector from an RLState."""
|
| 96 |
+
x = [0.0] * FEATURE_DIM
|
| 97 |
+
|
| 98 |
+
# Error class one-hot [0..7]
|
| 99 |
+
x[state.error_class] = 1.0
|
| 100 |
+
|
| 101 |
+
# Attempt number normalized [8]
|
| 102 |
+
x[8] = state.attempt_number / 5.0
|
| 103 |
+
|
| 104 |
+
# Previous action one-hot [9..16]
|
| 105 |
+
if state.previous_action is not None:
|
| 106 |
+
x[9 + int(state.previous_action)] = 1.0
|
| 107 |
+
|
| 108 |
+
# Error changed flag [17]
|
| 109 |
+
x[17] = 1.0 if state.error_changed else 0.0
|
| 110 |
+
|
| 111 |
+
# Consecutive same error normalized [18]
|
| 112 |
+
x[18] = min(state.consecutive_same_error, 5) / 5.0
|
| 113 |
+
|
| 114 |
+
# Bias term [19]
|
| 115 |
+
x[19] = 1.0
|
| 116 |
+
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# βββ Experience / Episode ββββββββββββββββββββββββββββββββββββββββ
|
| 121 |
+
|
| 122 |
+
class EpisodeStep(BaseModel):
|
| 123 |
+
state: RLState
|
| 124 |
+
featurized: List[float]
|
| 125 |
+
action: RepairAction
|
| 126 |
+
reward: float
|
| 127 |
+
error_message: str
|
| 128 |
+
sql: str
|
| 129 |
+
success: bool
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class Episode(BaseModel):
|
| 133 |
+
id: str
|
| 134 |
+
question: str
|
| 135 |
+
steps: List[EpisodeStep]
|
| 136 |
+
total_reward: float
|
| 137 |
+
success: bool
|
| 138 |
+
timestamp: float
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class Experience(BaseModel):
|
| 142 |
+
state: List[float]
|
| 143 |
+
action: RepairAction
|
| 144 |
+
reward: float
|
| 145 |
+
next_state: Optional[List[float]] = None
|
| 146 |
+
done: bool
|
| 147 |
+
timestamp: float
|
| 148 |
+
metadata: Dict[str, Any]
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# βββ Metrics ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 152 |
+
|
| 153 |
+
class RLMetrics(BaseModel):
|
| 154 |
+
total_episodes: int
|
| 155 |
+
total_steps: int
|
| 156 |
+
cumulative_reward: float
|
| 157 |
+
success_rate: float
|
| 158 |
+
avg_attempts: float
|
| 159 |
+
action_distribution: Dict[str, int]
|
| 160 |
+
error_distribution: Dict[str, int]
|
| 161 |
+
reward_history: List[float]
|
frontend/index.html
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8" />
|
| 5 |
+
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
+
<title>SQL Agent OpenEnv β RL Environment</title>
|
| 8 |
+
<meta name="description" content="SQL Agent with Reinforcement Learning and GEPA prompt evolution" />
|
| 9 |
+
</head>
|
| 10 |
+
<body>
|
| 11 |
+
<div id="root"></div>
|
| 12 |
+
<script type="module" src="/src/main.tsx"></script>
|
| 13 |
+
</body>
|
| 14 |
+
</html>
|
frontend/package-lock.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
frontend/package.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "sql-openenv-ui",
|
| 3 |
+
"private": true,
|
| 4 |
+
"version": "0.1.0",
|
| 5 |
+
"type": "module",
|
| 6 |
+
"scripts": {
|
| 7 |
+
"dev": "vite --port 5173",
|
| 8 |
+
"build": "vite build",
|
| 9 |
+
"preview": "vite preview"
|
| 10 |
+
},
|
| 11 |
+
"dependencies": {
|
| 12 |
+
"react": "^19.0.0",
|
| 13 |
+
"react-dom": "^19.0.0",
|
| 14 |
+
"framer-motion": "^11.0.0",
|
| 15 |
+
"lucide-react": "^0.400.0",
|
| 16 |
+
"recharts": "^2.12.0",
|
| 17 |
+
"zustand": "^4.5.0",
|
| 18 |
+
"react-markdown": "^9.0.0"
|
| 19 |
+
},
|
| 20 |
+
"devDependencies": {
|
| 21 |
+
"@types/react": "^19.0.0",
|
| 22 |
+
"@types/react-dom": "^19.0.0",
|
| 23 |
+
"@vitejs/plugin-react": "^4.3.0",
|
| 24 |
+
"typescript": "^5.5.0",
|
| 25 |
+
"vite": "^5.4.0",
|
| 26 |
+
"tailwindcss": "^3.4.0",
|
| 27 |
+
"autoprefixer": "^10.4.0",
|
| 28 |
+
"postcss": "^8.4.0"
|
| 29 |
+
}
|
| 30 |
+
}
|
frontend/postcss.config.js
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export default {
|
| 2 |
+
plugins: {
|
| 3 |
+
tailwindcss: {},
|
| 4 |
+
autoprefixer: {},
|
| 5 |
+
},
|
| 6 |
+
}
|
frontend/src/App.tsx
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useState, useEffect } from 'react'
|
| 2 |
+
import { motion, AnimatePresence } from 'framer-motion'
|
| 3 |
+
import { MessageSquare, Target, GitFork, X } from 'lucide-react'
|
| 4 |
+
|
| 5 |
+
import { Header } from './components/Header'
|
| 6 |
+
import { LeftSidebar } from './components/LeftSidebar'
|
| 7 |
+
import { ChatPanel } from './components/ChatPanel'
|
| 8 |
+
import { BenchmarkPanel } from './components/BenchmarkPanel'
|
| 9 |
+
import { ERDiagram } from './components/ERDiagram'
|
| 10 |
+
import { RightSidebar } from './components/RightSidebar'
|
| 11 |
+
import { useStore } from './store/useStore'
|
| 12 |
+
import { fetchInit } from './lib/api'
|
| 13 |
+
|
| 14 |
+
type Tab = 'chat' | 'benchmark' | 'er'
|
| 15 |
+
|
| 16 |
+
const TABS: { id: Tab; label: string; icon: React.ReactNode }[] = [
|
| 17 |
+
{ id: 'chat', label: 'Chat', icon: <MessageSquare size={12} /> },
|
| 18 |
+
{ id: 'benchmark', label: 'Benchmark', icon: <Target size={12} /> },
|
| 19 |
+
{ id: 'er', label: 'ER Diagram', icon: <GitFork size={12} /> },
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
export default function App() {
|
| 23 |
+
const [activeTab, setActiveTab] = useState<Tab>('chat')
|
| 24 |
+
const [leftOpen, setLeftOpen] = useState(false)
|
| 25 |
+
const [rightOpen, setRightOpen] = useState(false)
|
| 26 |
+
|
| 27 |
+
const { theme, setDbSeeded, setTables, setSchemaGraph } = useStore()
|
| 28 |
+
|
| 29 |
+
// Apply theme on mount / change
|
| 30 |
+
useEffect(() => {
|
| 31 |
+
document.documentElement.setAttribute('data-theme', theme)
|
| 32 |
+
}, [theme])
|
| 33 |
+
|
| 34 |
+
// Restore theme from storage on mount
|
| 35 |
+
useEffect(() => {
|
| 36 |
+
try {
|
| 37 |
+
const saved = localStorage.getItem('theme') as 'dark' | 'light' | null
|
| 38 |
+
if (saved) {
|
| 39 |
+
document.documentElement.setAttribute('data-theme', saved)
|
| 40 |
+
useStore.setState({ theme: saved })
|
| 41 |
+
}
|
| 42 |
+
} catch { /* noop */ }
|
| 43 |
+
}, [])
|
| 44 |
+
|
| 45 |
+
// Fetch init data
|
| 46 |
+
useEffect(() => {
|
| 47 |
+
fetchInit()
|
| 48 |
+
.then((d) => {
|
| 49 |
+
setDbSeeded(true)
|
| 50 |
+
setTables(d.tables)
|
| 51 |
+
// Lazy-load schema graph
|
| 52 |
+
fetch('/api/schema-graph')
|
| 53 |
+
.then((r) => r.json())
|
| 54 |
+
.then((g) => setSchemaGraph(g))
|
| 55 |
+
.catch(() => { /* noop */ })
|
| 56 |
+
})
|
| 57 |
+
.catch(() => { /* noop */ })
|
| 58 |
+
}, [setDbSeeded, setTables, setSchemaGraph])
|
| 59 |
+
|
| 60 |
+
// Close mobile sidebars on tab change
|
| 61 |
+
useEffect(() => {
|
| 62 |
+
setLeftOpen(false)
|
| 63 |
+
setRightOpen(false)
|
| 64 |
+
}, [activeTab])
|
| 65 |
+
|
| 66 |
+
return (
|
| 67 |
+
<div
|
| 68 |
+
className="h-screen flex flex-col overflow-hidden theme-bg-primary theme-text-primary"
|
| 69 |
+
style={{ fontFamily: 'ui-monospace,"SF Mono",Consolas,"Liberation Mono",monospace' }}
|
| 70 |
+
>
|
| 71 |
+
<Header
|
| 72 |
+
onToggleLeft={() => { setLeftOpen((v) => !v); setRightOpen(false) }}
|
| 73 |
+
onToggleRight={() => { setRightOpen((v) => !v); setLeftOpen(false) }}
|
| 74 |
+
/>
|
| 75 |
+
|
| 76 |
+
<div className="flex flex-1 overflow-hidden relative">
|
| 77 |
+
{/* Overlay backdrop (mobile) */}
|
| 78 |
+
{(leftOpen || rightOpen) && (
|
| 79 |
+
<div
|
| 80 |
+
className="fixed inset-0 bg-black/50 z-30 lg:hidden"
|
| 81 |
+
onClick={() => { setLeftOpen(false); setRightOpen(false) }}
|
| 82 |
+
/>
|
| 83 |
+
)}
|
| 84 |
+
|
| 85 |
+
{/* LEFT SIDEBAR */}
|
| 86 |
+
<aside
|
| 87 |
+
className={`
|
| 88 |
+
fixed top-[53px] bottom-0 left-0 z-40 w-60 border-r theme-border flex flex-col overflow-y-auto
|
| 89 |
+
transition-transform duration-200 ease-out
|
| 90 |
+
lg:static lg:w-60 lg:shrink-0 lg:translate-x-0 lg:z-auto
|
| 91 |
+
${leftOpen ? 'translate-x-0' : '-translate-x-full'}
|
| 92 |
+
`}
|
| 93 |
+
style={{ background: 'var(--bg-secondary)' }}
|
| 94 |
+
>
|
| 95 |
+
<div className="flex items-center justify-between px-4 pt-3 pb-1 lg:hidden">
|
| 96 |
+
<span className="text-[10px] font-semibold text-gray-500 uppercase tracking-wider">
|
| 97 |
+
Dataset & Tasks
|
| 98 |
+
</span>
|
| 99 |
+
<button
|
| 100 |
+
onClick={() => setLeftOpen(false)}
|
| 101 |
+
className="p-1 rounded hover:bg-white/5 text-gray-500"
|
| 102 |
+
>
|
| 103 |
+
<X size={14} />
|
| 104 |
+
</button>
|
| 105 |
+
</div>
|
| 106 |
+
<div className="flex-1 px-4 py-3">
|
| 107 |
+
<LeftSidebar />
|
| 108 |
+
</div>
|
| 109 |
+
</aside>
|
| 110 |
+
|
| 111 |
+
{/* CENTER: Tabbed panel */}
|
| 112 |
+
<main className="flex-1 flex flex-col overflow-hidden min-w-0">
|
| 113 |
+
{/* Tab bar */}
|
| 114 |
+
<div
|
| 115 |
+
className="flex items-center gap-1 px-2 sm:px-4 py-2.5 border-b theme-border shrink-0 overflow-x-auto scrollbar-none"
|
| 116 |
+
style={{ background: 'var(--bg-secondary)' }}
|
| 117 |
+
>
|
| 118 |
+
{TABS.map((tab) => (
|
| 119 |
+
<button
|
| 120 |
+
key={tab.id}
|
| 121 |
+
onClick={() => setActiveTab(tab.id)}
|
| 122 |
+
className={`flex items-center gap-1.5 px-2.5 sm:px-3 py-1.5 rounded-lg text-xs font-medium transition-all whitespace-nowrap shrink-0 ${
|
| 123 |
+
activeTab === tab.id
|
| 124 |
+
? 'bg-violet-600/20 text-violet-300 border border-violet-500/30'
|
| 125 |
+
: 'text-gray-500 hover:text-gray-300 hover:bg-white/5 border border-transparent'
|
| 126 |
+
}`}
|
| 127 |
+
>
|
| 128 |
+
{tab.icon}
|
| 129 |
+
<span>{tab.label}</span>
|
| 130 |
+
</button>
|
| 131 |
+
))}
|
| 132 |
+
</div>
|
| 133 |
+
|
| 134 |
+
{/* Tab content */}
|
| 135 |
+
<div className="flex-1 overflow-hidden relative">
|
| 136 |
+
<AnimatePresence mode="wait">
|
| 137 |
+
<motion.div
|
| 138 |
+
key={activeTab}
|
| 139 |
+
initial={{ opacity: 0, y: 4 }}
|
| 140 |
+
animate={{ opacity: 1, y: 0 }}
|
| 141 |
+
exit={{ opacity: 0 }}
|
| 142 |
+
transition={{ duration: 0.15 }}
|
| 143 |
+
className="absolute inset-0 flex flex-col overflow-hidden"
|
| 144 |
+
>
|
| 145 |
+
{activeTab === 'chat' && <ChatPanel />}
|
| 146 |
+
{activeTab === 'benchmark' && <BenchmarkPanel />}
|
| 147 |
+
{activeTab === 'er' && <ERDiagram />}
|
| 148 |
+
</motion.div>
|
| 149 |
+
</AnimatePresence>
|
| 150 |
+
</div>
|
| 151 |
+
</main>
|
| 152 |
+
|
| 153 |
+
{/* RIGHT SIDEBAR */}
|
| 154 |
+
<aside
|
| 155 |
+
className={`
|
| 156 |
+
fixed top-[53px] bottom-0 right-0 z-40 w-72 border-l theme-border flex flex-col overflow-hidden
|
| 157 |
+
transition-transform duration-200 ease-out
|
| 158 |
+
lg:static lg:w-72 lg:shrink-0 lg:translate-x-0 lg:z-auto
|
| 159 |
+
${rightOpen ? 'translate-x-0' : 'translate-x-full'}
|
| 160 |
+
`}
|
| 161 |
+
style={{ background: 'var(--bg-secondary)' }}
|
| 162 |
+
>
|
| 163 |
+
<div className="flex items-center justify-between px-4 pt-3 pb-1 lg:hidden">
|
| 164 |
+
<span className="text-[10px] font-semibold text-gray-500 uppercase tracking-wider">
|
| 165 |
+
GEPA & RL
|
| 166 |
+
</span>
|
| 167 |
+
<button
|
| 168 |
+
onClick={() => setRightOpen(false)}
|
| 169 |
+
className="p-1 rounded hover:bg-white/5 text-gray-500"
|
| 170 |
+
>
|
| 171 |
+
<X size={14} />
|
| 172 |
+
</button>
|
| 173 |
+
</div>
|
| 174 |
+
<RightSidebar />
|
| 175 |
+
</aside>
|
| 176 |
+
</div>
|
| 177 |
+
</div>
|
| 178 |
+
)
|
| 179 |
+
}
|
frontend/src/components/BenchmarkPanel.tsx
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useState, useCallback } from 'react'
|
| 2 |
+
import { motion, AnimatePresence } from 'framer-motion'
|
| 3 |
+
import {
|
| 4 |
+
Target, Play, Loader2, CheckCircle2, XCircle,
|
| 5 |
+
ChevronDown, RotateCcw, Zap,
|
| 6 |
+
} from 'lucide-react'
|
| 7 |
+
import { useStore } from '../store/useStore'
|
| 8 |
+
import { streamBenchmark } from '../lib/api'
|
| 9 |
+
import type { BenchmarkResult, Difficulty } from '../lib/types'
|
| 10 |
+
|
| 11 |
+
const DIFFICULTY_TABS: { id: Difficulty; label: string }[] = [
|
| 12 |
+
{ id: 'easy', label: 'Easy' },
|
| 13 |
+
{ id: 'medium', label: 'Medium' },
|
| 14 |
+
{ id: 'hard', label: 'Hard' },
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
function QueryRow({
|
| 18 |
+
result,
|
| 19 |
+
isActive,
|
| 20 |
+
isExpanded,
|
| 21 |
+
onToggleExpand,
|
| 22 |
+
onRunSingle,
|
| 23 |
+
isRunning,
|
| 24 |
+
dbSeeded,
|
| 25 |
+
}: {
|
| 26 |
+
result: BenchmarkResult
|
| 27 |
+
isActive: boolean
|
| 28 |
+
isExpanded: boolean
|
| 29 |
+
onToggleExpand: () => void
|
| 30 |
+
onRunSingle: () => void
|
| 31 |
+
isRunning: boolean
|
| 32 |
+
dbSeeded: boolean
|
| 33 |
+
}) {
|
| 34 |
+
const statusIcon = () => {
|
| 35 |
+
switch (result.status) {
|
| 36 |
+
case 'pending': return <span className="w-2 h-2 rounded-full bg-gray-600 shrink-0" />
|
| 37 |
+
case 'running': return <Loader2 size={12} className="text-violet-400 animate-spin shrink-0" />
|
| 38 |
+
case 'pass': return <CheckCircle2 size={12} className="text-green-400 shrink-0" />
|
| 39 |
+
case 'fail': return <XCircle size={12} className="text-red-400 shrink-0" />
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
const difficultyColor =
|
| 44 |
+
result.difficulty === 'hard'
|
| 45 |
+
? 'text-red-400 bg-red-500/10 border-red-500/25'
|
| 46 |
+
: result.difficulty === 'medium'
|
| 47 |
+
? 'text-amber-400 bg-amber-500/10 border-amber-500/25'
|
| 48 |
+
: 'text-blue-400 bg-blue-500/10 border-blue-500/25'
|
| 49 |
+
|
| 50 |
+
return (
|
| 51 |
+
<div
|
| 52 |
+
className={`rounded-xl border transition-all duration-150 ${
|
| 53 |
+
isActive
|
| 54 |
+
? 'border-violet-500/40 bg-violet-500/5'
|
| 55 |
+
: 'border-white/5 bg-white/[0.02] hover:bg-white/[0.04]'
|
| 56 |
+
}`}
|
| 57 |
+
>
|
| 58 |
+
<div
|
| 59 |
+
className="flex items-start gap-2 px-3 py-2.5 cursor-pointer"
|
| 60 |
+
onClick={onToggleExpand}
|
| 61 |
+
>
|
| 62 |
+
<div className="mt-0.5 shrink-0">{statusIcon()}</div>
|
| 63 |
+
<div className="flex-1 min-w-0">
|
| 64 |
+
<div className="flex items-center gap-2 mb-0.5 flex-wrap">
|
| 65 |
+
<span className="text-[10px] font-mono text-gray-600">{result.id}</span>
|
| 66 |
+
<span className={`text-[9px] font-semibold px-1.5 py-0.5 rounded-full border ${difficultyColor}`}>
|
| 67 |
+
{result.difficulty}
|
| 68 |
+
</span>
|
| 69 |
+
{result.score !== null && (
|
| 70 |
+
<span className={`text-[10px] font-mono font-bold ${result.status === 'pass' ? 'text-green-400' : 'text-red-400'}`}>
|
| 71 |
+
{result.score.toFixed(2)}
|
| 72 |
+
</span>
|
| 73 |
+
)}
|
| 74 |
+
{result.attempts !== null && (
|
| 75 |
+
<span className="text-[9px] text-gray-600 font-mono">
|
| 76 |
+
{result.attempts} attempt{result.attempts !== 1 ? 's' : ''}
|
| 77 |
+
</span>
|
| 78 |
+
)}
|
| 79 |
+
</div>
|
| 80 |
+
<div className="text-xs text-gray-300 leading-relaxed line-clamp-2">
|
| 81 |
+
{result.question}
|
| 82 |
+
</div>
|
| 83 |
+
{result.reason && result.status !== 'pending' && (
|
| 84 |
+
<div className={`text-[10px] mt-1 ${result.status === 'pass' ? 'text-green-500/70' : 'text-red-400/70'}`}>
|
| 85 |
+
{result.reason.length > 120 ? result.reason.slice(0, 120) + 'β¦' : result.reason}
|
| 86 |
+
</div>
|
| 87 |
+
)}
|
| 88 |
+
</div>
|
| 89 |
+
<div className="flex items-center gap-1.5 shrink-0">
|
| 90 |
+
{result.status === 'pending' && dbSeeded && !isRunning && (
|
| 91 |
+
<button
|
| 92 |
+
onClick={(e) => { e.stopPropagation(); onRunSingle() }}
|
| 93 |
+
className="p-1 rounded-lg hover:bg-white/10 transition-colors"
|
| 94 |
+
title="Run this query"
|
| 95 |
+
>
|
| 96 |
+
<Play size={10} className="text-gray-500 hover:text-violet-400" />
|
| 97 |
+
</button>
|
| 98 |
+
)}
|
| 99 |
+
<ChevronDown
|
| 100 |
+
size={11}
|
| 101 |
+
className={`text-gray-600 transition-transform duration-150 ${isExpanded ? 'rotate-180' : ''}`}
|
| 102 |
+
/>
|
| 103 |
+
</div>
|
| 104 |
+
</div>
|
| 105 |
+
|
| 106 |
+
{/* Expanded detail */}
|
| 107 |
+
<AnimatePresence>
|
| 108 |
+
{isExpanded && (
|
| 109 |
+
<motion.div
|
| 110 |
+
initial={{ height: 0, opacity: 0 }}
|
| 111 |
+
animate={{ height: 'auto', opacity: 1 }}
|
| 112 |
+
exit={{ height: 0, opacity: 0 }}
|
| 113 |
+
transition={{ duration: 0.15 }}
|
| 114 |
+
className="overflow-hidden"
|
| 115 |
+
>
|
| 116 |
+
<div className="px-3 pb-3 flex flex-col gap-2 border-t border-white/5 pt-2">
|
| 117 |
+
<p className="text-xs text-gray-400 leading-relaxed">{result.question}</p>
|
| 118 |
+
|
| 119 |
+
{result.sql && (
|
| 120 |
+
<div>
|
| 121 |
+
<div className="text-[10px] text-gray-600 mb-1 font-semibold uppercase tracking-wider">
|
| 122 |
+
Generated SQL
|
| 123 |
+
</div>
|
| 124 |
+
<pre className="text-[10px] font-mono text-violet-200/70 bg-black/40 rounded-lg p-2.5 border border-white/5 whitespace-pre-wrap leading-relaxed max-h-40 overflow-y-auto">
|
| 125 |
+
{result.sql}
|
| 126 |
+
</pre>
|
| 127 |
+
</div>
|
| 128 |
+
)}
|
| 129 |
+
|
| 130 |
+
{(result.refRowCount !== null || result.reason) && (
|
| 131 |
+
<div className="flex flex-col gap-1.5">
|
| 132 |
+
{result.refRowCount !== null && (
|
| 133 |
+
<div className="flex items-center gap-3 text-[10px] font-mono">
|
| 134 |
+
<span className="text-gray-600">reference:</span>
|
| 135 |
+
<span className="text-blue-400">{result.refRowCount} rows</span>
|
| 136 |
+
<span className="text-gray-600">agent:</span>
|
| 137 |
+
<span className={
|
| 138 |
+
result.agentRowCount === result.refRowCount
|
| 139 |
+
? 'text-green-400'
|
| 140 |
+
: result.agentRowCount === 0
|
| 141 |
+
? 'text-red-400'
|
| 142 |
+
: 'text-amber-400'
|
| 143 |
+
}>
|
| 144 |
+
{result.agentRowCount ?? 0} rows
|
| 145 |
+
</span>
|
| 146 |
+
</div>
|
| 147 |
+
)}
|
| 148 |
+
{result.reason && (
|
| 149 |
+
<div className={`text-[10px] leading-relaxed ${result.status === 'pass' ? 'text-green-400/80' : 'text-red-400/80'}`}>
|
| 150 |
+
{result.reason}
|
| 151 |
+
</div>
|
| 152 |
+
)}
|
| 153 |
+
</div>
|
| 154 |
+
)}
|
| 155 |
+
|
| 156 |
+
{result.status !== 'pending' && result.status !== 'running' && !isRunning && dbSeeded && (
|
| 157 |
+
<button
|
| 158 |
+
onClick={(e) => { e.stopPropagation(); onRunSingle() }}
|
| 159 |
+
className="flex items-center gap-1 text-[10px] text-violet-400 hover:text-violet-300 transition-colors self-start mt-1"
|
| 160 |
+
>
|
| 161 |
+
<RotateCcw size={9} />
|
| 162 |
+
Re-run
|
| 163 |
+
</button>
|
| 164 |
+
)}
|
| 165 |
+
</div>
|
| 166 |
+
</motion.div>
|
| 167 |
+
)}
|
| 168 |
+
</AnimatePresence>
|
| 169 |
+
</div>
|
| 170 |
+
)
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
export function BenchmarkPanel() {
|
| 174 |
+
const {
|
| 175 |
+
benchmarkResults, isBenchmarking, overallScore,
|
| 176 |
+
activeBenchmarkId, dbSeeded,
|
| 177 |
+
setIsBenchmarking, updateBenchmarkResult, setOverallScore,
|
| 178 |
+
setActiveBenchmarkId, resetBenchmark,
|
| 179 |
+
taskDifficulty, setTaskDifficulty,
|
| 180 |
+
} = useStore()
|
| 181 |
+
|
| 182 |
+
const [expandedIds, setExpandedIds] = useState<Set<string>>(new Set())
|
| 183 |
+
|
| 184 |
+
const toggleExpand = (id: string) => {
|
| 185 |
+
setExpandedIds((prev) => {
|
| 186 |
+
const next = new Set(prev)
|
| 187 |
+
if (next.has(id)) next.delete(id)
|
| 188 |
+
else next.add(id)
|
| 189 |
+
return next
|
| 190 |
+
})
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
const runBenchmark = useCallback(
|
| 194 |
+
async (queryIds?: string[]) => {
|
| 195 |
+
if (isBenchmarking) return
|
| 196 |
+
setIsBenchmarking(true)
|
| 197 |
+
|
| 198 |
+
const targetIds = queryIds ?? benchmarkResults.map((r) => r.id)
|
| 199 |
+
for (const id of targetIds) {
|
| 200 |
+
const existing = benchmarkResults.find((r) => r.id === id)
|
| 201 |
+
if (existing) {
|
| 202 |
+
updateBenchmarkResult({ ...existing, status: 'running', score: null, reason: null, sql: null })
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
try {
|
| 207 |
+
for await (const event of streamBenchmark(taskDifficulty, queryIds)) {
|
| 208 |
+
if (event.type === 'query_start') {
|
| 209 |
+
setActiveBenchmarkId(event.id as string)
|
| 210 |
+
const existing = benchmarkResults.find((r) => r.id === event.id)
|
| 211 |
+
if (existing) updateBenchmarkResult({ ...existing, status: 'running' })
|
| 212 |
+
} else if (event.type === 'query_result') {
|
| 213 |
+
const existing = benchmarkResults.find((r) => r.id === event.id)
|
| 214 |
+
if (existing) {
|
| 215 |
+
updateBenchmarkResult({
|
| 216 |
+
...existing,
|
| 217 |
+
status: (event.pass as boolean) ? 'pass' : 'fail',
|
| 218 |
+
score: event.score as number,
|
| 219 |
+
reason: event.reason as string,
|
| 220 |
+
sql: event.sql as string,
|
| 221 |
+
attempts: (event.attempts as number) ?? null,
|
| 222 |
+
refRowCount: (event.refRowCount as number) ?? null,
|
| 223 |
+
agentRowCount: (event.agentRowCount as number) ?? null,
|
| 224 |
+
})
|
| 225 |
+
}
|
| 226 |
+
} else if (event.type === 'done') {
|
| 227 |
+
setOverallScore(event.overallScore as number)
|
| 228 |
+
setActiveBenchmarkId(null)
|
| 229 |
+
setIsBenchmarking(false)
|
| 230 |
+
} else if (event.type === 'error') {
|
| 231 |
+
setActiveBenchmarkId(null)
|
| 232 |
+
setIsBenchmarking(false)
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
} catch {
|
| 236 |
+
setIsBenchmarking(false)
|
| 237 |
+
setActiveBenchmarkId(null)
|
| 238 |
+
}
|
| 239 |
+
},
|
| 240 |
+
[isBenchmarking, benchmarkResults, setIsBenchmarking, updateBenchmarkResult,
|
| 241 |
+
setOverallScore, setActiveBenchmarkId, taskDifficulty]
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
const passedCount = benchmarkResults.filter((r) => r.status === 'pass').length
|
| 245 |
+
const completedCount = benchmarkResults.filter((r) => r.status === 'pass' || r.status === 'fail').length
|
| 246 |
+
const totalScore = benchmarkResults.reduce((s, r) => s + (r.score ?? 0), 0)
|
| 247 |
+
const progressPct = benchmarkResults.length > 0 ? Math.round((completedCount / benchmarkResults.length) * 100) : 0
|
| 248 |
+
const scorePct = completedCount > 0 ? Math.round((totalScore / benchmarkResults.length) * 100) : 0
|
| 249 |
+
|
| 250 |
+
return (
|
| 251 |
+
<div className="flex flex-col h-full">
|
| 252 |
+
{/* Header */}
|
| 253 |
+
<div className="px-4 py-3 border-b border-white/[0.06] shrink-0">
|
| 254 |
+
<div className="flex items-center justify-between mb-2">
|
| 255 |
+
<div className="flex items-center gap-2">
|
| 256 |
+
<Target size={14} className="text-violet-400" />
|
| 257 |
+
<span className="text-xs font-semibold text-white">Benchmark</span>
|
| 258 |
+
{completedCount > 0 && (
|
| 259 |
+
<span className="text-xs text-gray-500 font-mono">
|
| 260 |
+
{passedCount}/{benchmarkResults.length} passed
|
| 261 |
+
</span>
|
| 262 |
+
)}
|
| 263 |
+
</div>
|
| 264 |
+
<div className="flex items-center gap-2">
|
| 265 |
+
{completedCount > 0 && (
|
| 266 |
+
<button
|
| 267 |
+
onClick={resetBenchmark}
|
| 268 |
+
disabled={isBenchmarking}
|
| 269 |
+
className="flex items-center gap-1 px-2 py-1 rounded-lg text-[10px] text-gray-500 hover:text-gray-300 hover:bg-white/5 transition-all disabled:opacity-40"
|
| 270 |
+
>
|
| 271 |
+
<RotateCcw size={10} />
|
| 272 |
+
Reset
|
| 273 |
+
</button>
|
| 274 |
+
)}
|
| 275 |
+
<button
|
| 276 |
+
onClick={() => void runBenchmark()}
|
| 277 |
+
disabled={isBenchmarking || !dbSeeded}
|
| 278 |
+
className="flex items-center gap-1.5 px-3 py-1.5 rounded-lg bg-violet-600 hover:bg-violet-500 disabled:opacity-40 disabled:cursor-not-allowed transition-all text-white text-xs font-semibold"
|
| 279 |
+
>
|
| 280 |
+
{isBenchmarking ? (
|
| 281 |
+
<Loader2 size={11} className="animate-spin" />
|
| 282 |
+
) : (
|
| 283 |
+
<Play size={11} />
|
| 284 |
+
)}
|
| 285 |
+
Run All
|
| 286 |
+
</button>
|
| 287 |
+
</div>
|
| 288 |
+
</div>
|
| 289 |
+
|
| 290 |
+
{/* Overall score */}
|
| 291 |
+
{overallScore !== null && (
|
| 292 |
+
<motion.div
|
| 293 |
+
initial={{ opacity: 0, scale: 0.95 }}
|
| 294 |
+
animate={{ opacity: 1, scale: 1 }}
|
| 295 |
+
className="mb-2 flex items-center gap-3 px-3 py-2 rounded-xl border border-violet-500/20 bg-violet-500/5"
|
| 296 |
+
>
|
| 297 |
+
<Zap size={14} className="text-violet-400 shrink-0" />
|
| 298 |
+
<div>
|
| 299 |
+
<div className="text-[10px] text-gray-500 uppercase tracking-wider">Overall Score</div>
|
| 300 |
+
<div className="text-xl font-bold font-mono text-violet-300">
|
| 301 |
+
{(overallScore * 100).toFixed(0)}%
|
| 302 |
+
</div>
|
| 303 |
+
</div>
|
| 304 |
+
</motion.div>
|
| 305 |
+
)}
|
| 306 |
+
|
| 307 |
+
{/* Score bar */}
|
| 308 |
+
{completedCount > 0 && (
|
| 309 |
+
<div className="flex flex-col gap-1">
|
| 310 |
+
<div className="flex items-center justify-between text-[10px]">
|
| 311 |
+
<span className="text-gray-500">
|
| 312 |
+
Score: {totalScore.toFixed(1)}/{benchmarkResults.length}
|
| 313 |
+
</span>
|
| 314 |
+
<span className="text-violet-400 font-mono">{scorePct}%</span>
|
| 315 |
+
</div>
|
| 316 |
+
<div className="h-1.5 bg-white/5 rounded-full overflow-hidden">
|
| 317 |
+
<motion.div
|
| 318 |
+
className="h-full rounded-full bg-gradient-to-r from-violet-600 to-violet-400"
|
| 319 |
+
initial={{ width: 0 }}
|
| 320 |
+
animate={{ width: `${scorePct}%` }}
|
| 321 |
+
transition={{ duration: 0.5, ease: 'easeOut' }}
|
| 322 |
+
/>
|
| 323 |
+
</div>
|
| 324 |
+
</div>
|
| 325 |
+
)}
|
| 326 |
+
|
| 327 |
+
{/* Progress */}
|
| 328 |
+
{isBenchmarking && (
|
| 329 |
+
<div className="mt-1.5">
|
| 330 |
+
<div className="h-1 bg-white/5 rounded-full overflow-hidden">
|
| 331 |
+
<motion.div
|
| 332 |
+
className="h-full rounded-full bg-violet-500/60"
|
| 333 |
+
initial={{ width: 0 }}
|
| 334 |
+
animate={{ width: `${progressPct}%` }}
|
| 335 |
+
transition={{ duration: 0.3 }}
|
| 336 |
+
/>
|
| 337 |
+
</div>
|
| 338 |
+
</div>
|
| 339 |
+
)}
|
| 340 |
+
</div>
|
| 341 |
+
|
| 342 |
+
{/* Difficulty tabs */}
|
| 343 |
+
<div className="flex items-center gap-1 px-4 py-2 border-b border-white/[0.06] shrink-0">
|
| 344 |
+
{DIFFICULTY_TABS.map((tab) => (
|
| 345 |
+
<button
|
| 346 |
+
key={tab.id}
|
| 347 |
+
onClick={() => setTaskDifficulty(tab.id)}
|
| 348 |
+
className={`px-3 py-1 rounded-lg text-xs font-medium transition-all ${
|
| 349 |
+
taskDifficulty === tab.id
|
| 350 |
+
? 'bg-violet-600/20 text-violet-300 border border-violet-500/30'
|
| 351 |
+
: 'text-gray-500 hover:text-gray-300 hover:bg-white/5 border border-transparent'
|
| 352 |
+
}`}
|
| 353 |
+
>
|
| 354 |
+
{tab.label}
|
| 355 |
+
</button>
|
| 356 |
+
))}
|
| 357 |
+
</div>
|
| 358 |
+
|
| 359 |
+
{/* Query list */}
|
| 360 |
+
<div className="flex-1 overflow-y-auto">
|
| 361 |
+
<div className="p-2 flex flex-col gap-1">
|
| 362 |
+
{benchmarkResults.map((result) => (
|
| 363 |
+
<QueryRow
|
| 364 |
+
key={result.id}
|
| 365 |
+
result={result}
|
| 366 |
+
isActive={activeBenchmarkId === result.id}
|
| 367 |
+
isExpanded={expandedIds.has(result.id)}
|
| 368 |
+
onToggleExpand={() => toggleExpand(result.id)}
|
| 369 |
+
onRunSingle={() => void runBenchmark([result.id])}
|
| 370 |
+
isRunning={isBenchmarking}
|
| 371 |
+
dbSeeded={dbSeeded}
|
| 372 |
+
/>
|
| 373 |
+
))}
|
| 374 |
+
</div>
|
| 375 |
+
</div>
|
| 376 |
+
|
| 377 |
+
{!dbSeeded && (
|
| 378 |
+
<div className="px-4 py-2 border-t border-white/[0.06] text-[10px] text-gray-600 text-center shrink-0">
|
| 379 |
+
Waiting for database initialization...
|
| 380 |
+
</div>
|
| 381 |
+
)}
|
| 382 |
+
</div>
|
| 383 |
+
)
|
| 384 |
+
}
|
frontend/src/components/ChatPanel.tsx
ADDED
|
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useState, useRef, useEffect, useCallback } from 'react'
|
| 2 |
+
import { motion, AnimatePresence } from 'framer-motion'
|
| 3 |
+
import {
|
| 4 |
+
Send, CheckCircle2, XCircle, ChevronDown, ChevronUp,
|
| 5 |
+
Loader2, MessageSquare, Zap, RefreshCw, Trash2,
|
| 6 |
+
} from 'lucide-react'
|
| 7 |
+
import { useStore } from '../store/useStore'
|
| 8 |
+
import { streamExecuteQuery, submitFeedback } from '../lib/api'
|
| 9 |
+
import { ResultsTable } from './ResultsTable'
|
| 10 |
+
import type { ChatMessage, AttemptStep } from '../lib/types'
|
| 11 |
+
|
| 12 |
+
// βββ SQL Syntax Highlighter βββββββββββββββββββββββββββββββββββββββ
|
| 13 |
+
|
| 14 |
+
const SQL_KEYWORDS = /\b(SELECT|FROM|WHERE|JOIN|LEFT|RIGHT|INNER|OUTER|FULL|ON|GROUP\s+BY|ORDER\s+BY|HAVING|LIMIT|OFFSET|UNION|ALL|DISTINCT|AS|AND|OR|NOT|IN|IS|NULL|LIKE|BETWEEN|CASE|WHEN|THEN|ELSE|END|WITH|CTE|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER|TABLE|INDEX|VIEW|SET|VALUES|INTO|EXISTS|COUNT|SUM|AVG|MIN|MAX|COALESCE|NULLIF|CAST|OVER|PARTITION\s+BY|ROW_NUMBER|RANK|DENSE_RANK|LAG|LEAD|DATE|STRFTIME|JULIANDAY|ROUND|ABS|LENGTH|SUBSTR|UPPER|LOWER|TRIM|REPLACE|IFNULL)\b/gi
|
| 15 |
+
|
| 16 |
+
function SqlBlock({ sql, streaming }: { sql: string; streaming?: boolean }) {
|
| 17 |
+
const parts: React.ReactNode[] = []
|
| 18 |
+
let last = 0
|
| 19 |
+
let match: RegExpExecArray | null
|
| 20 |
+
|
| 21 |
+
const re = new RegExp(SQL_KEYWORDS.source, 'gi')
|
| 22 |
+
while ((match = re.exec(sql)) !== null) {
|
| 23 |
+
if (match.index > last) {
|
| 24 |
+
parts.push(<span key={`t-${last}`}>{sql.slice(last, match.index)}</span>)
|
| 25 |
+
}
|
| 26 |
+
parts.push(
|
| 27 |
+
<span key={`k-${match.index}`} className="sql-keyword">
|
| 28 |
+
{match[0]}
|
| 29 |
+
</span>
|
| 30 |
+
)
|
| 31 |
+
last = match.index + match[0].length
|
| 32 |
+
}
|
| 33 |
+
if (last < sql.length) {
|
| 34 |
+
parts.push(<span key={`t-end`}>{sql.slice(last)}</span>)
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
return (
|
| 38 |
+
<pre
|
| 39 |
+
className="px-3 py-2.5 text-xs font-mono bg-violet-950/20 whitespace-pre-wrap overflow-x-auto leading-relaxed border-t border-white/[0.04]"
|
| 40 |
+
style={{ color: 'rgba(221, 214, 254, 0.8)' }}
|
| 41 |
+
>
|
| 42 |
+
{parts}
|
| 43 |
+
{streaming && <span className="cursor-blink" />}
|
| 44 |
+
</pre>
|
| 45 |
+
)
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// βββ Attempt badge ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
|
| 50 |
+
function AttemptBadge({ attempt, total }: { attempt: number; total: number }) {
|
| 51 |
+
const colors =
|
| 52 |
+
attempt === 1
|
| 53 |
+
? 'text-gray-400 bg-white/5 border-white/10'
|
| 54 |
+
: attempt === 2
|
| 55 |
+
? 'text-amber-400 bg-amber-500/10 border-amber-500/20'
|
| 56 |
+
: attempt === 3
|
| 57 |
+
? 'text-orange-400 bg-orange-500/10 border-orange-500/20'
|
| 58 |
+
: 'text-red-400 bg-red-500/10 border-red-500/20'
|
| 59 |
+
|
| 60 |
+
return (
|
| 61 |
+
<span className={`text-[10px] font-semibold px-2 py-0.5 rounded-full border ${colors}`}>
|
| 62 |
+
Attempt {attempt}/{total}
|
| 63 |
+
</span>
|
| 64 |
+
)
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// βββ RL Action badge ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 68 |
+
|
| 69 |
+
function RLActionBadge({ action, score }: { action: string; score?: number }) {
|
| 70 |
+
return (
|
| 71 |
+
<span className="inline-flex items-center gap-1 text-[10px] font-semibold px-2 py-0.5 rounded-full border border-orange-500/30 bg-orange-500/10 text-orange-400">
|
| 72 |
+
<Zap size={9} />
|
| 73 |
+
{action}
|
| 74 |
+
{score !== undefined && (
|
| 75 |
+
<span className="text-orange-400/60 ml-0.5">{score.toFixed(2)}</span>
|
| 76 |
+
)}
|
| 77 |
+
</span>
|
| 78 |
+
)
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// βββ Reward display ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 82 |
+
|
| 83 |
+
function RewardBadge({ reward }: { reward: number }) {
|
| 84 |
+
const positive = reward >= 0
|
| 85 |
+
return (
|
| 86 |
+
<motion.span
|
| 87 |
+
initial={{ scale: 0.8, opacity: 0 }}
|
| 88 |
+
animate={{ scale: 1, opacity: 1 }}
|
| 89 |
+
transition={{ type: 'spring', stiffness: 300 }}
|
| 90 |
+
className={`inline-flex items-center gap-0.5 text-[11px] font-bold tabular-nums reward-pulse ${
|
| 91 |
+
positive ? 'text-green-400' : 'text-red-400'
|
| 92 |
+
}`}
|
| 93 |
+
>
|
| 94 |
+
{positive ? '+' : ''}{reward.toFixed(2)}
|
| 95 |
+
</motion.span>
|
| 96 |
+
)
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
// βββ Attempt steps collapsible ββββββββββββββββββββββββββββββββββββ
|
| 100 |
+
|
| 101 |
+
function AttemptSteps({ steps }: { steps: AttemptStep[] }) {
|
| 102 |
+
const [open, setOpen] = useState(false)
|
| 103 |
+
if (steps.length <= 1) return null
|
| 104 |
+
|
| 105 |
+
return (
|
| 106 |
+
<div className="border border-white/[0.05] rounded-xl overflow-hidden">
|
| 107 |
+
<button
|
| 108 |
+
onClick={() => setOpen((v) => !v)}
|
| 109 |
+
className="w-full flex items-center justify-between px-3 py-2 bg-white/[0.02] hover:bg-white/[0.04] transition-colors text-[10px] text-gray-500"
|
| 110 |
+
>
|
| 111 |
+
<span>{steps.length} attempts to solve</span>
|
| 112 |
+
{open ? <ChevronUp size={11} /> : <ChevronDown size={11} />}
|
| 113 |
+
</button>
|
| 114 |
+
<AnimatePresence>
|
| 115 |
+
{open && (
|
| 116 |
+
<motion.div
|
| 117 |
+
initial={{ height: 0, opacity: 0 }}
|
| 118 |
+
animate={{ height: 'auto', opacity: 1 }}
|
| 119 |
+
exit={{ height: 0, opacity: 0 }}
|
| 120 |
+
transition={{ duration: 0.15 }}
|
| 121 |
+
className="overflow-hidden"
|
| 122 |
+
>
|
| 123 |
+
<div className="flex flex-col divide-y divide-white/[0.04]">
|
| 124 |
+
{steps.map((step) => (
|
| 125 |
+
<div key={step.attempt} className="px-3 py-2">
|
| 126 |
+
<div className="flex items-center gap-2 mb-1.5">
|
| 127 |
+
<AttemptBadge attempt={step.attempt} total={steps.length} />
|
| 128 |
+
{step.action && (
|
| 129 |
+
<RLActionBadge action={step.action} score={step.actionScore} />
|
| 130 |
+
)}
|
| 131 |
+
{step.reward !== undefined && <RewardBadge reward={step.reward} />}
|
| 132 |
+
</div>
|
| 133 |
+
{step.error && (
|
| 134 |
+
<div className="text-[10px] text-red-400/70 mb-1 bg-red-500/5 rounded px-2 py-1 border border-red-500/15">
|
| 135 |
+
{step.error}
|
| 136 |
+
</div>
|
| 137 |
+
)}
|
| 138 |
+
<SqlBlock sql={step.sql} />
|
| 139 |
+
</div>
|
| 140 |
+
))}
|
| 141 |
+
</div>
|
| 142 |
+
</motion.div>
|
| 143 |
+
)}
|
| 144 |
+
</AnimatePresence>
|
| 145 |
+
</div>
|
| 146 |
+
)
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
// βββ Suggested query chips ββββββββββββββββββββββββββββββββββββββββ
|
| 150 |
+
|
| 151 |
+
const SUGGESTED: Record<string, string[]> = {
|
| 152 |
+
easy: ['Show all products', 'List users from USA', 'What categories exist?'],
|
| 153 |
+
medium: ['Top 5 sellers by revenue', 'Average order value by country', 'Products with low stock'],
|
| 154 |
+
hard: ['Rolling 7-day revenue', 'Seller ranking with rank change', 'Cohort retention analysis'],
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
function EmptyState({ onSelect }: { onSelect: (q: string) => void }) {
|
| 158 |
+
const { taskDifficulty } = useStore()
|
| 159 |
+
const suggestions = SUGGESTED[taskDifficulty] ?? SUGGESTED.easy
|
| 160 |
+
|
| 161 |
+
return (
|
| 162 |
+
<div className="flex flex-col items-center justify-center h-full gap-6 px-8 text-center">
|
| 163 |
+
<div>
|
| 164 |
+
<div
|
| 165 |
+
className="w-12 h-12 rounded-2xl flex items-center justify-center mx-auto mb-4"
|
| 166 |
+
style={{ background: '#1e3a5f', boxShadow: '0 8px 24px rgba(30,58,95,0.4)' }}
|
| 167 |
+
>
|
| 168 |
+
<MessageSquare size={22} className="text-white" />
|
| 169 |
+
</div>
|
| 170 |
+
<h2 className="text-base font-semibold text-white mb-1">Ask about your data</h2>
|
| 171 |
+
<p className="text-xs text-gray-500 max-w-xs">
|
| 172 |
+
Type a question in natural language. The agent will generate SQL, execute it,
|
| 173 |
+
and self-repair on errors using reinforcement learning.
|
| 174 |
+
</p>
|
| 175 |
+
</div>
|
| 176 |
+
|
| 177 |
+
<div className="flex flex-col gap-2 w-full max-w-sm">
|
| 178 |
+
<div className="text-[10px] text-gray-600 uppercase tracking-wider mb-0.5">
|
| 179 |
+
Try these queries
|
| 180 |
+
</div>
|
| 181 |
+
{suggestions.map((q) => (
|
| 182 |
+
<button
|
| 183 |
+
key={q}
|
| 184 |
+
onClick={() => onSelect(q)}
|
| 185 |
+
className="flex items-center gap-2 px-3 py-2.5 rounded-xl border border-white/[0.06] bg-white/[0.02] hover:bg-white/[0.05] hover:border-violet-500/30 transition-all text-left group"
|
| 186 |
+
>
|
| 187 |
+
<span className="text-violet-500 shrink-0 group-hover:text-violet-400">βΊ</span>
|
| 188 |
+
<span className="text-xs text-gray-300">{q}</span>
|
| 189 |
+
</button>
|
| 190 |
+
))}
|
| 191 |
+
</div>
|
| 192 |
+
</div>
|
| 193 |
+
)
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
// βββ Message Card βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 197 |
+
|
| 198 |
+
function MessageCard({
|
| 199 |
+
msg,
|
| 200 |
+
onFeedback,
|
| 201 |
+
onRetry,
|
| 202 |
+
}: {
|
| 203 |
+
msg: ChatMessage
|
| 204 |
+
onFeedback: (id: string, correct: boolean) => Promise<void>
|
| 205 |
+
onRetry: (q: string) => void
|
| 206 |
+
}) {
|
| 207 |
+
const [sqlOpen, setSqlOpen] = useState(true)
|
| 208 |
+
|
| 209 |
+
return (
|
| 210 |
+
<div className="flex flex-col gap-2.5">
|
| 211 |
+
{/* User question bubble */}
|
| 212 |
+
<div className="flex justify-end">
|
| 213 |
+
<div className="max-w-[80%] bg-violet-600/20 border border-violet-500/25 rounded-2xl rounded-tr-sm px-4 py-2.5">
|
| 214 |
+
<p className="text-sm text-white leading-relaxed">{msg.question}</p>
|
| 215 |
+
</div>
|
| 216 |
+
</div>
|
| 217 |
+
|
| 218 |
+
{/* Agent response */}
|
| 219 |
+
<div className="flex flex-col gap-2">
|
| 220 |
+
{/* Streaming thinking */}
|
| 221 |
+
{msg.status === 'streaming' && !msg.sql && (
|
| 222 |
+
<div className="flex items-center gap-2 text-xs text-gray-500 px-1">
|
| 223 |
+
<Loader2 size={11} className="animate-spin text-violet-400" />
|
| 224 |
+
Generating SQL...
|
| 225 |
+
</div>
|
| 226 |
+
)}
|
| 227 |
+
|
| 228 |
+
{/* Multiple attempts */}
|
| 229 |
+
<AttemptSteps steps={msg.steps} />
|
| 230 |
+
|
| 231 |
+
{/* Final SQL block */}
|
| 232 |
+
{msg.sql && (
|
| 233 |
+
<div className="border border-white/[0.06] rounded-xl overflow-hidden">
|
| 234 |
+
<button
|
| 235 |
+
onClick={() => setSqlOpen((v) => !v)}
|
| 236 |
+
className="w-full flex items-center justify-between px-3 py-2 bg-white/[0.02] hover:bg-white/[0.04] transition-colors"
|
| 237 |
+
>
|
| 238 |
+
<div className="flex items-center gap-2">
|
| 239 |
+
<span className="text-[10px] font-semibold text-gray-500 uppercase tracking-wider">
|
| 240 |
+
SQL
|
| 241 |
+
</span>
|
| 242 |
+
{msg.status === 'streaming' && (
|
| 243 |
+
<Loader2 size={10} className="animate-spin text-violet-400" />
|
| 244 |
+
)}
|
| 245 |
+
{msg.attempts > 1 && (
|
| 246 |
+
<AttemptBadge attempt={msg.attempts} total={msg.attempts} />
|
| 247 |
+
)}
|
| 248 |
+
</div>
|
| 249 |
+
{sqlOpen ? (
|
| 250 |
+
<ChevronUp size={11} className="text-gray-600" />
|
| 251 |
+
) : (
|
| 252 |
+
<ChevronDown size={11} className="text-gray-600" />
|
| 253 |
+
)}
|
| 254 |
+
</button>
|
| 255 |
+
{sqlOpen && (
|
| 256 |
+
<SqlBlock sql={msg.sql} streaming={msg.status === 'streaming'} />
|
| 257 |
+
)}
|
| 258 |
+
</div>
|
| 259 |
+
)}
|
| 260 |
+
|
| 261 |
+
{/* Executing indicator */}
|
| 262 |
+
{msg.status === 'streaming' && msg.sql && msg.rows.length === 0 && !msg.errorMsg && (
|
| 263 |
+
<div className="flex items-center gap-2 text-xs text-gray-500 px-1">
|
| 264 |
+
<Loader2 size={11} className="animate-spin text-violet-400" />
|
| 265 |
+
Executing...
|
| 266 |
+
</div>
|
| 267 |
+
)}
|
| 268 |
+
|
| 269 |
+
{/* RL badges row */}
|
| 270 |
+
{(msg.rlAction || msg.reward !== undefined) && (
|
| 271 |
+
<div className="flex items-center gap-2 flex-wrap">
|
| 272 |
+
{msg.rlAction && (
|
| 273 |
+
<RLActionBadge action={msg.rlAction} score={msg.rlActionScore} />
|
| 274 |
+
)}
|
| 275 |
+
{msg.reward !== undefined && <RewardBadge reward={msg.reward} />}
|
| 276 |
+
</div>
|
| 277 |
+
)}
|
| 278 |
+
|
| 279 |
+
{/* Result table */}
|
| 280 |
+
{msg.status === 'done' && msg.attempts > 0 && (
|
| 281 |
+
<div className="flex flex-col gap-1.5">
|
| 282 |
+
<div className="flex items-center gap-2 text-[10px] px-0.5">
|
| 283 |
+
<CheckCircle2 size={11} className="text-green-400" />
|
| 284 |
+
<span className="text-green-400 font-semibold">Success</span>
|
| 285 |
+
<span className="text-gray-600">
|
| 286 |
+
Β· {msg.rowCount} row{msg.rowCount !== 1 ? 's' : ''}
|
| 287 |
+
</span>
|
| 288 |
+
{msg.attempts > 1 && (
|
| 289 |
+
<span className="text-amber-400/60">{msg.attempts} attempts</span>
|
| 290 |
+
)}
|
| 291 |
+
</div>
|
| 292 |
+
<ResultsTable rows={msg.rows} rowCount={msg.rowCount} />
|
| 293 |
+
</div>
|
| 294 |
+
)}
|
| 295 |
+
|
| 296 |
+
{/* Error */}
|
| 297 |
+
{msg.status === 'error' && (
|
| 298 |
+
<div className="flex items-start gap-2 bg-red-500/10 border border-red-500/20 rounded-xl px-3 py-2.5 text-xs text-red-300">
|
| 299 |
+
<XCircle size={12} className="shrink-0 mt-0.5" />
|
| 300 |
+
<div>
|
| 301 |
+
<p className="font-semibold mb-0.5">Query failed</p>
|
| 302 |
+
<p className="opacity-80">{msg.errorMsg ?? 'Agent exhausted all repair attempts'}</p>
|
| 303 |
+
</div>
|
| 304 |
+
</div>
|
| 305 |
+
)}
|
| 306 |
+
|
| 307 |
+
{/* Feedback */}
|
| 308 |
+
{msg.status === 'done' && msg.attempts > 0 && (
|
| 309 |
+
<div className="flex items-center gap-2">
|
| 310 |
+
{msg.feedback ? (
|
| 311 |
+
<div
|
| 312 |
+
className={`text-xs flex items-center gap-1.5 ${
|
| 313 |
+
msg.feedback === 'correct' ? 'text-green-400' : 'text-red-400'
|
| 314 |
+
}`}
|
| 315 |
+
>
|
| 316 |
+
{msg.feedback === 'correct' ? (
|
| 317 |
+
<CheckCircle2 size={12} />
|
| 318 |
+
) : (
|
| 319 |
+
<XCircle size={12} />
|
| 320 |
+
)}
|
| 321 |
+
Marked as {msg.feedback}
|
| 322 |
+
</div>
|
| 323 |
+
) : (
|
| 324 |
+
<>
|
| 325 |
+
<span className="text-[10px] text-gray-600 mr-0.5">Was this correct?</span>
|
| 326 |
+
<button
|
| 327 |
+
disabled={msg.feedbackSending}
|
| 328 |
+
onClick={() => onFeedback(msg.id, true)}
|
| 329 |
+
className="flex items-center gap-1 px-2 py-1 text-[10px] font-medium rounded-lg border border-green-500/25 bg-green-500/8 text-green-400 hover:bg-green-500/15 transition-all disabled:opacity-40"
|
| 330 |
+
>
|
| 331 |
+
<CheckCircle2 size={10} />
|
| 332 |
+
Correct
|
| 333 |
+
</button>
|
| 334 |
+
<button
|
| 335 |
+
disabled={msg.feedbackSending}
|
| 336 |
+
onClick={() => onFeedback(msg.id, false)}
|
| 337 |
+
className="flex items-center gap-1 px-2 py-1 text-[10px] font-medium rounded-lg border border-red-500/25 bg-red-500/8 text-red-400 hover:bg-red-500/15 transition-all disabled:opacity-40"
|
| 338 |
+
>
|
| 339 |
+
<XCircle size={10} />
|
| 340 |
+
Wrong
|
| 341 |
+
</button>
|
| 342 |
+
</>
|
| 343 |
+
)}
|
| 344 |
+
{(msg.status === 'done' || msg.status === 'error') && (
|
| 345 |
+
<button
|
| 346 |
+
onClick={() => onRetry(msg.question)}
|
| 347 |
+
className="ml-auto flex items-center gap-1 text-[10px] text-gray-600 hover:text-gray-400 transition-colors"
|
| 348 |
+
>
|
| 349 |
+
<RefreshCw size={10} />
|
| 350 |
+
Retry
|
| 351 |
+
</button>
|
| 352 |
+
)}
|
| 353 |
+
</div>
|
| 354 |
+
)}
|
| 355 |
+
</div>
|
| 356 |
+
</div>
|
| 357 |
+
)
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
// βββ Chat Panel βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 361 |
+
|
| 362 |
+
export function ChatPanel() {
|
| 363 |
+
const {
|
| 364 |
+
messages, addMessage, updateMessage, clearMessages,
|
| 365 |
+
isExecuting, setIsExecuting,
|
| 366 |
+
taskId, taskDifficulty,
|
| 367 |
+
optimizingBanner, setOptimizingBanner,
|
| 368 |
+
promptGeneration,
|
| 369 |
+
} = useStore()
|
| 370 |
+
|
| 371 |
+
const [input, setInput] = useState('')
|
| 372 |
+
const bottomRef = useRef<HTMLDivElement>(null)
|
| 373 |
+
const inputRef = useRef<HTMLTextAreaElement>(null)
|
| 374 |
+
|
| 375 |
+
useEffect(() => {
|
| 376 |
+
bottomRef.current?.scrollIntoView({ behavior: 'smooth' })
|
| 377 |
+
}, [messages.length])
|
| 378 |
+
|
| 379 |
+
const handleFeedback = useCallback(
|
| 380 |
+
async (id: string, correct: boolean) => {
|
| 381 |
+
const msg = messages.find((m) => m.id === id)
|
| 382 |
+
if (!msg) return
|
| 383 |
+
updateMessage(id, { feedbackSending: true })
|
| 384 |
+
try {
|
| 385 |
+
await submitFeedback(msg.question, msg.sql, correct)
|
| 386 |
+
updateMessage(id, { feedback: correct ? 'correct' : 'wrong', feedbackSending: false })
|
| 387 |
+
} catch {
|
| 388 |
+
updateMessage(id, { feedbackSending: false })
|
| 389 |
+
}
|
| 390 |
+
},
|
| 391 |
+
[messages, updateMessage]
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
const execute = useCallback(
|
| 395 |
+
async (question: string) => {
|
| 396 |
+
if (!question.trim() || isExecuting) return
|
| 397 |
+
setIsExecuting(true)
|
| 398 |
+
|
| 399 |
+
const msgId = `msg-${Date.now()}`
|
| 400 |
+
const newMsg: ChatMessage = {
|
| 401 |
+
id: msgId,
|
| 402 |
+
question,
|
| 403 |
+
status: 'streaming',
|
| 404 |
+
sql: '',
|
| 405 |
+
rows: [],
|
| 406 |
+
rowCount: 0,
|
| 407 |
+
attempts: 0,
|
| 408 |
+
steps: [],
|
| 409 |
+
feedback: null,
|
| 410 |
+
promptGeneration,
|
| 411 |
+
}
|
| 412 |
+
addMessage(newMsg)
|
| 413 |
+
|
| 414 |
+
try {
|
| 415 |
+
for await (const event of streamExecuteQuery(question, taskId)) {
|
| 416 |
+
if (event.type === 'sql') {
|
| 417 |
+
updateMessage(msgId, { sql: event.sql as string })
|
| 418 |
+
} else if (event.type === 'sql_chunk') {
|
| 419 |
+
// incremental SQL streaming β read current sql from store
|
| 420 |
+
const curSql = useStore.getState().messages.find((m) => m.id === msgId)?.sql ?? ''
|
| 421 |
+
updateMessage(msgId, { sql: curSql + (event.chunk as string) })
|
| 422 |
+
} else if (event.type === 'attempt') {
|
| 423 |
+
const step: AttemptStep = {
|
| 424 |
+
attempt: event.attempt as number,
|
| 425 |
+
sql: event.sql as string,
|
| 426 |
+
error: event.error as string | undefined,
|
| 427 |
+
action: event.action as string | undefined,
|
| 428 |
+
actionScore: event.action_score as number | undefined,
|
| 429 |
+
reward: event.reward as number | undefined,
|
| 430 |
+
}
|
| 431 |
+
const curSteps = useStore.getState().messages.find((m) => m.id === msgId)?.steps ?? []
|
| 432 |
+
updateMessage(msgId, {
|
| 433 |
+
attempts: event.attempt as number,
|
| 434 |
+
steps: [...curSteps, step],
|
| 435 |
+
sql: event.sql as string,
|
| 436 |
+
rlAction: event.action as string | undefined,
|
| 437 |
+
rlActionScore: event.action_score as number | undefined,
|
| 438 |
+
})
|
| 439 |
+
} else if (event.type === 'result') {
|
| 440 |
+
updateMessage(msgId, {
|
| 441 |
+
rows: (event.rows as Record<string, unknown>[]) ?? [],
|
| 442 |
+
rowCount: (event.row_count as number) ?? 0,
|
| 443 |
+
reward: event.reward as number | undefined,
|
| 444 |
+
})
|
| 445 |
+
} else if (event.type === 'done') {
|
| 446 |
+
updateMessage(msgId, {
|
| 447 |
+
status: 'done',
|
| 448 |
+
attempts: (event.attempts as number) ?? 1,
|
| 449 |
+
reward: event.reward as number | undefined,
|
| 450 |
+
})
|
| 451 |
+
} else if (event.type === 'error') {
|
| 452 |
+
updateMessage(msgId, {
|
| 453 |
+
status: 'error',
|
| 454 |
+
errorMsg: event.message as string,
|
| 455 |
+
})
|
| 456 |
+
} else if (event.type === 'gepa_start') {
|
| 457 |
+
setOptimizingBanner(true)
|
| 458 |
+
} else if (event.type === 'gepa_done') {
|
| 459 |
+
setOptimizingBanner(false)
|
| 460 |
+
}
|
| 461 |
+
}
|
| 462 |
+
} catch (err) {
|
| 463 |
+
updateMessage(msgId, {
|
| 464 |
+
status: 'error',
|
| 465 |
+
errorMsg: err instanceof Error ? err.message : 'Network error',
|
| 466 |
+
})
|
| 467 |
+
} finally {
|
| 468 |
+
setIsExecuting(false)
|
| 469 |
+
// If still streaming after generator ends, mark done
|
| 470 |
+
const finalMsg = useStore.getState().messages.find((m) => m.id === msgId)
|
| 471 |
+
if (finalMsg?.status === 'streaming') {
|
| 472 |
+
updateMessage(msgId, { status: finalMsg.sql ? 'done' : 'error' })
|
| 473 |
+
}
|
| 474 |
+
}
|
| 475 |
+
},
|
| 476 |
+
[isExecuting, setIsExecuting, addMessage, updateMessage, taskId, promptGeneration, setOptimizingBanner]
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
const handleSend = () => {
|
| 480 |
+
if (!input.trim()) return
|
| 481 |
+
const q = input.trim()
|
| 482 |
+
setInput('')
|
| 483 |
+
void execute(q)
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
| 487 |
+
if (e.key === 'Enter' && !e.shiftKey) {
|
| 488 |
+
e.preventDefault()
|
| 489 |
+
handleSend()
|
| 490 |
+
}
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
const suggestions = SUGGESTED[taskDifficulty] ?? SUGGESTED.easy
|
| 494 |
+
|
| 495 |
+
return (
|
| 496 |
+
<div className="flex flex-col h-full">
|
| 497 |
+
{/* Optimizing banner */}
|
| 498 |
+
<AnimatePresence>
|
| 499 |
+
{optimizingBanner && (
|
| 500 |
+
<motion.div
|
| 501 |
+
initial={{ height: 0, opacity: 0 }}
|
| 502 |
+
animate={{ height: 'auto', opacity: 1 }}
|
| 503 |
+
exit={{ height: 0, opacity: 0 }}
|
| 504 |
+
className="shrink-0 overflow-hidden"
|
| 505 |
+
>
|
| 506 |
+
<div className="shimmer-banner border-b border-violet-500/20 px-4 py-2 flex items-center gap-2">
|
| 507 |
+
<Loader2 size={12} className="animate-spin text-violet-400" />
|
| 508 |
+
<span className="text-xs text-violet-300 font-semibold">
|
| 509 |
+
Optimizing system prompt via GEPA...
|
| 510 |
+
</span>
|
| 511 |
+
</div>
|
| 512 |
+
</motion.div>
|
| 513 |
+
)}
|
| 514 |
+
</AnimatePresence>
|
| 515 |
+
|
| 516 |
+
{/* Messages */}
|
| 517 |
+
<div className="flex-1 overflow-y-auto px-4 py-4">
|
| 518 |
+
{messages.length === 0 ? (
|
| 519 |
+
<EmptyState onSelect={(q) => { setInput(q); inputRef.current?.focus() }} />
|
| 520 |
+
) : (
|
| 521 |
+
<div className="flex flex-col gap-6 max-w-3xl mx-auto">
|
| 522 |
+
{messages.map((msg) => (
|
| 523 |
+
<MessageCard
|
| 524 |
+
key={msg.id}
|
| 525 |
+
msg={msg}
|
| 526 |
+
onFeedback={handleFeedback}
|
| 527 |
+
onRetry={(q) => { setInput(q); inputRef.current?.focus() }}
|
| 528 |
+
/>
|
| 529 |
+
))}
|
| 530 |
+
<div ref={bottomRef} />
|
| 531 |
+
</div>
|
| 532 |
+
)}
|
| 533 |
+
</div>
|
| 534 |
+
|
| 535 |
+
{/* Input area */}
|
| 536 |
+
<div
|
| 537 |
+
className="shrink-0 border-t border-white/[0.06] px-4 py-3"
|
| 538 |
+
style={{ background: 'var(--bg-secondary)' }}
|
| 539 |
+
>
|
| 540 |
+
{/* Suggested chips */}
|
| 541 |
+
{messages.length > 0 && (
|
| 542 |
+
<div className="flex gap-1.5 flex-wrap mb-2.5">
|
| 543 |
+
{suggestions.slice(0, 3).map((q) => (
|
| 544 |
+
<button
|
| 545 |
+
key={q}
|
| 546 |
+
onClick={() => { setInput(q); inputRef.current?.focus() }}
|
| 547 |
+
className="text-[10px] px-2.5 py-1 rounded-full border border-white/[0.06] text-gray-500 hover:text-gray-300 hover:border-violet-500/30 transition-all"
|
| 548 |
+
>
|
| 549 |
+
{q}
|
| 550 |
+
</button>
|
| 551 |
+
))}
|
| 552 |
+
</div>
|
| 553 |
+
)}
|
| 554 |
+
|
| 555 |
+
<div className="flex items-end gap-2">
|
| 556 |
+
<div className="flex-1 relative">
|
| 557 |
+
<textarea
|
| 558 |
+
ref={inputRef}
|
| 559 |
+
value={input}
|
| 560 |
+
onChange={(e) => setInput(e.target.value)}
|
| 561 |
+
onKeyDown={handleKeyDown}
|
| 562 |
+
placeholder="Ask about products, orders, sellers..."
|
| 563 |
+
disabled={isExecuting}
|
| 564 |
+
rows={1}
|
| 565 |
+
className="w-full px-3 py-2.5 pr-10 text-sm text-white rounded-xl border border-white/[0.06] bg-white/[0.03] placeholder-gray-600 resize-none focus:outline-none focus:border-violet-500/40 focus:bg-white/[0.05] transition-all disabled:opacity-50"
|
| 566 |
+
style={{ minHeight: 40, maxHeight: 120, overflowY: 'auto' }}
|
| 567 |
+
/>
|
| 568 |
+
</div>
|
| 569 |
+
<div className="flex flex-col gap-1.5 shrink-0">
|
| 570 |
+
<button
|
| 571 |
+
onClick={handleSend}
|
| 572 |
+
disabled={!input.trim() || isExecuting}
|
| 573 |
+
className="w-9 h-9 rounded-xl bg-violet-600 hover:bg-violet-500 disabled:opacity-40 disabled:cursor-not-allowed transition-all flex items-center justify-center"
|
| 574 |
+
>
|
| 575 |
+
{isExecuting ? (
|
| 576 |
+
<Loader2 size={14} className="animate-spin text-white" />
|
| 577 |
+
) : (
|
| 578 |
+
<Send size={14} className="text-white" />
|
| 579 |
+
)}
|
| 580 |
+
</button>
|
| 581 |
+
{messages.length > 0 && (
|
| 582 |
+
<button
|
| 583 |
+
onClick={clearMessages}
|
| 584 |
+
disabled={isExecuting}
|
| 585 |
+
className="w-9 h-9 rounded-xl border border-white/[0.06] hover:bg-white/5 disabled:opacity-40 transition-all flex items-center justify-center text-gray-600 hover:text-gray-400"
|
| 586 |
+
title="Clear chat"
|
| 587 |
+
>
|
| 588 |
+
<Trash2 size={12} />
|
| 589 |
+
</button>
|
| 590 |
+
)}
|
| 591 |
+
</div>
|
| 592 |
+
</div>
|
| 593 |
+
<p className="text-[9px] text-gray-700 mt-1.5 text-center">
|
| 594 |
+
Enter to send Β· Shift+Enter for newline Β· Agent uses LinUCB + GEPA
|
| 595 |
+
</p>
|
| 596 |
+
</div>
|
| 597 |
+
</div>
|
| 598 |
+
)
|
| 599 |
+
}
|
frontend/src/components/ERDiagram.tsx
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useState, useEffect, useRef } from 'react'
|
| 2 |
+
import { Loader2, GitFork } from 'lucide-react'
|
| 3 |
+
import { useStore } from '../store/useStore'
|
| 4 |
+
import { fetchSchemaGraph } from '../lib/api'
|
| 5 |
+
import type { SchemaTable, SchemaRelationship } from '../lib/types'
|
| 6 |
+
|
| 7 |
+
// βββ Table card βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 8 |
+
|
| 9 |
+
function TableCard({ table, x, y }: { table: SchemaTable; x: number; y: number }) {
|
| 10 |
+
return (
|
| 11 |
+
<g transform={`translate(${x},${y})`}>
|
| 12 |
+
{/* Card bg */}
|
| 13 |
+
<rect
|
| 14 |
+
width={180}
|
| 15 |
+
height={28 + table.columns.length * 20}
|
| 16 |
+
rx={8}
|
| 17 |
+
fill="#0e0e16"
|
| 18 |
+
stroke="rgba(255,255,255,0.08)"
|
| 19 |
+
strokeWidth={1}
|
| 20 |
+
/>
|
| 21 |
+
{/* Header */}
|
| 22 |
+
<rect width={180} height={28} rx={8} fill="rgba(139,92,246,0.15)" />
|
| 23 |
+
<rect y={20} width={180} height={8} fill="rgba(139,92,246,0.15)" />
|
| 24 |
+
<text
|
| 25 |
+
x={10}
|
| 26 |
+
y={18}
|
| 27 |
+
fill="#a78bfa"
|
| 28 |
+
fontSize={11}
|
| 29 |
+
fontWeight="bold"
|
| 30 |
+
fontFamily="ui-monospace,monospace"
|
| 31 |
+
>
|
| 32 |
+
{table.name}
|
| 33 |
+
</text>
|
| 34 |
+
|
| 35 |
+
{/* Columns */}
|
| 36 |
+
{table.columns.map((col, i) => (
|
| 37 |
+
<g key={col.name} transform={`translate(0,${28 + i * 20})`}>
|
| 38 |
+
<rect
|
| 39 |
+
width={180}
|
| 40 |
+
height={20}
|
| 41 |
+
fill={i % 2 === 0 ? 'rgba(255,255,255,0.01)' : 'transparent'}
|
| 42 |
+
/>
|
| 43 |
+
<text
|
| 44 |
+
x={10}
|
| 45 |
+
y={14}
|
| 46 |
+
fill={col.pk ? '#60a5fa' : col.fk ? '#34d399' : 'rgba(255,255,255,0.5)'}
|
| 47 |
+
fontSize={10}
|
| 48 |
+
fontFamily="ui-monospace,monospace"
|
| 49 |
+
>
|
| 50 |
+
{col.pk ? 'π ' : col.fk ? 'π ' : ' '}
|
| 51 |
+
{col.name}
|
| 52 |
+
</text>
|
| 53 |
+
<text
|
| 54 |
+
x={170}
|
| 55 |
+
y={14}
|
| 56 |
+
fill="rgba(255,255,255,0.2)"
|
| 57 |
+
fontSize={9}
|
| 58 |
+
fontFamily="ui-monospace,monospace"
|
| 59 |
+
textAnchor="end"
|
| 60 |
+
>
|
| 61 |
+
{col.type}
|
| 62 |
+
</text>
|
| 63 |
+
</g>
|
| 64 |
+
))}
|
| 65 |
+
</g>
|
| 66 |
+
)
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
// βββ Layout helpers βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
+
|
| 71 |
+
function layoutTables(tables: SchemaTable[]) {
|
| 72 |
+
const CARD_W = 180
|
| 73 |
+
const CARD_H_BASE = 28
|
| 74 |
+
const COL_H = 20
|
| 75 |
+
const GAP_X = 40
|
| 76 |
+
const GAP_Y = 30
|
| 77 |
+
const COLS_PER_ROW = 3
|
| 78 |
+
|
| 79 |
+
const positions: Record<string, { x: number; y: number; w: number; h: number }> = {}
|
| 80 |
+
let maxRowH = 0
|
| 81 |
+
|
| 82 |
+
tables.forEach((t, i) => {
|
| 83 |
+
const col = i % COLS_PER_ROW
|
| 84 |
+
const row = Math.floor(i / COLS_PER_ROW)
|
| 85 |
+
const h = CARD_H_BASE + t.columns.length * COL_H
|
| 86 |
+
|
| 87 |
+
if (row === Math.floor(i / COLS_PER_ROW) && col === 0) maxRowH = 0
|
| 88 |
+
maxRowH = Math.max(maxRowH, h)
|
| 89 |
+
|
| 90 |
+
const prevRowsH = tables
|
| 91 |
+
.slice(0, row * COLS_PER_ROW)
|
| 92 |
+
.reduce((acc, _, idx) => {
|
| 93 |
+
if (idx % COLS_PER_ROW === 0) {
|
| 94 |
+
const rowH = tables.slice(idx, idx + COLS_PER_ROW).reduce(
|
| 95 |
+
(m, rt) => Math.max(m, CARD_H_BASE + rt.columns.length * COL_H),
|
| 96 |
+
0
|
| 97 |
+
)
|
| 98 |
+
return acc + rowH + GAP_Y
|
| 99 |
+
}
|
| 100 |
+
return acc
|
| 101 |
+
}, 0)
|
| 102 |
+
|
| 103 |
+
positions[t.name] = {
|
| 104 |
+
x: col * (CARD_W + GAP_X) + 20,
|
| 105 |
+
y: prevRowsH + 20,
|
| 106 |
+
w: CARD_W,
|
| 107 |
+
h,
|
| 108 |
+
}
|
| 109 |
+
})
|
| 110 |
+
|
| 111 |
+
return positions
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
function RelationshipLine({
|
| 115 |
+
from,
|
| 116 |
+
to,
|
| 117 |
+
positions,
|
| 118 |
+
}: {
|
| 119 |
+
from: string
|
| 120 |
+
to: string
|
| 121 |
+
positions: Record<string, { x: number; y: number; w: number; h: number }>
|
| 122 |
+
}) {
|
| 123 |
+
const a = positions[from]
|
| 124 |
+
const b = positions[to]
|
| 125 |
+
if (!a || !b) return null
|
| 126 |
+
|
| 127 |
+
const x1 = a.x + a.w
|
| 128 |
+
const y1 = a.y + 14
|
| 129 |
+
const x2 = b.x
|
| 130 |
+
const y2 = b.y + 14
|
| 131 |
+
const cx = (x1 + x2) / 2
|
| 132 |
+
|
| 133 |
+
return (
|
| 134 |
+
<path
|
| 135 |
+
d={`M${x1},${y1} C${cx},${y1} ${cx},${y2} ${x2},${y2}`}
|
| 136 |
+
stroke="rgba(139,92,246,0.3)"
|
| 137 |
+
strokeWidth={1.5}
|
| 138 |
+
fill="none"
|
| 139 |
+
strokeDasharray="4 3"
|
| 140 |
+
/>
|
| 141 |
+
)
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
// βββ ER Diagram component βββββββββββββββββββββββββββββββββββββββββ
|
| 145 |
+
|
| 146 |
+
export function ERDiagram() {
|
| 147 |
+
const { schemaGraph, setSchemaGraph } = useStore()
|
| 148 |
+
const [loading, setLoading] = useState(false)
|
| 149 |
+
const svgRef = useRef<SVGSVGElement>(null)
|
| 150 |
+
|
| 151 |
+
const load = async () => {
|
| 152 |
+
setLoading(true)
|
| 153 |
+
try {
|
| 154 |
+
const data = await fetchSchemaGraph()
|
| 155 |
+
setSchemaGraph(data)
|
| 156 |
+
} catch {
|
| 157 |
+
// noop
|
| 158 |
+
} finally {
|
| 159 |
+
setLoading(false)
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
useEffect(() => {
|
| 164 |
+
if (!schemaGraph) void load()
|
| 165 |
+
// eslint-disable-next-line react-hooks/exhaustive-deps
|
| 166 |
+
}, [])
|
| 167 |
+
|
| 168 |
+
if (loading) {
|
| 169 |
+
return (
|
| 170 |
+
<div className="flex items-center justify-center h-full gap-2 text-gray-500">
|
| 171 |
+
<Loader2 size={16} className="animate-spin" />
|
| 172 |
+
<span className="text-sm">Loading schema...</span>
|
| 173 |
+
</div>
|
| 174 |
+
)
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
if (!schemaGraph || schemaGraph.tables.length === 0) {
|
| 178 |
+
return (
|
| 179 |
+
<div className="flex flex-col items-center justify-center h-full gap-3 text-gray-600">
|
| 180 |
+
<GitFork size={32} className="text-gray-700" />
|
| 181 |
+
<p className="text-sm">Schema will appear after database connects</p>
|
| 182 |
+
<button
|
| 183 |
+
onClick={() => void load()}
|
| 184 |
+
className="text-xs text-violet-400 hover:text-violet-300 transition-colors"
|
| 185 |
+
>
|
| 186 |
+
Retry
|
| 187 |
+
</button>
|
| 188 |
+
</div>
|
| 189 |
+
)
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
const { tables, relationships } = schemaGraph
|
| 193 |
+
const positions = layoutTables(tables)
|
| 194 |
+
|
| 195 |
+
const allX = Object.values(positions).map((p) => p.x + p.w)
|
| 196 |
+
const allY = Object.values(positions).map((p) => p.y + p.h)
|
| 197 |
+
const svgW = Math.max(...allX) + 40
|
| 198 |
+
const svgH = Math.max(...allY) + 40
|
| 199 |
+
|
| 200 |
+
return (
|
| 201 |
+
<div className="h-full overflow-auto p-4">
|
| 202 |
+
<div className="text-[10px] text-gray-500 uppercase tracking-widest mb-3 flex items-center gap-1.5">
|
| 203 |
+
<GitFork size={10} className="text-violet-400" />
|
| 204 |
+
Entity Relationship Diagram
|
| 205 |
+
<span className="text-gray-700">Β· {tables.length} tables</span>
|
| 206 |
+
</div>
|
| 207 |
+
<svg
|
| 208 |
+
ref={svgRef}
|
| 209 |
+
width={svgW}
|
| 210 |
+
height={svgH}
|
| 211 |
+
style={{ minWidth: svgW }}
|
| 212 |
+
>
|
| 213 |
+
{/* FK lines */}
|
| 214 |
+
{(relationships as SchemaRelationship[]).map((rel, i) => (
|
| 215 |
+
<RelationshipLine
|
| 216 |
+
key={i}
|
| 217 |
+
from={rel.from}
|
| 218 |
+
to={rel.to}
|
| 219 |
+
positions={positions}
|
| 220 |
+
/>
|
| 221 |
+
))}
|
| 222 |
+
{/* Tables */}
|
| 223 |
+
{tables.map((t: SchemaTable) => (
|
| 224 |
+
<TableCard
|
| 225 |
+
key={t.name}
|
| 226 |
+
table={t}
|
| 227 |
+
x={positions[t.name]?.x ?? 0}
|
| 228 |
+
y={positions[t.name]?.y ?? 0}
|
| 229 |
+
/>
|
| 230 |
+
))}
|
| 231 |
+
</svg>
|
| 232 |
+
</div>
|
| 233 |
+
)
|
| 234 |
+
}
|
frontend/src/components/Header.tsx
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { Database, Sun, Moon, PanelLeftOpen, PanelRightOpen, Cpu } from 'lucide-react'
|
| 2 |
+
import { useStore } from '../store/useStore'
|
| 3 |
+
import type { Difficulty } from '../lib/types'
|
| 4 |
+
|
| 5 |
+
interface HeaderProps {
|
| 6 |
+
onToggleLeft: () => void
|
| 7 |
+
onToggleRight: () => void
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
const DIFFICULTIES: { id: Difficulty; label: string; color: string }[] = [
|
| 11 |
+
{ id: 'easy', label: 'Easy', color: 'text-green-400 border-green-500/30 bg-green-500/10' },
|
| 12 |
+
{ id: 'medium', label: 'Medium', color: 'text-amber-400 border-amber-500/30 bg-amber-500/10' },
|
| 13 |
+
{ id: 'hard', label: 'Hard', color: 'text-red-400 border-red-500/30 bg-red-500/10' },
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
export function Header({ onToggleLeft, onToggleRight }: HeaderProps) {
|
| 17 |
+
const { theme, toggleTheme, dbSeeded, taskDifficulty, setTaskDifficulty } = useStore()
|
| 18 |
+
|
| 19 |
+
return (
|
| 20 |
+
<header
|
| 21 |
+
className="border-b px-3 sm:px-5 py-3 flex items-center justify-between shrink-0 backdrop-blur-sm sticky top-0 z-50 theme-border"
|
| 22 |
+
style={{ background: 'var(--bg-secondary)' }}
|
| 23 |
+
>
|
| 24 |
+
<div className="flex items-center gap-2 sm:gap-3">
|
| 25 |
+
{/* Mobile sidebar toggle */}
|
| 26 |
+
<button
|
| 27 |
+
onClick={onToggleLeft}
|
| 28 |
+
className="lg:hidden flex items-center gap-1 px-2 py-1.5 rounded-lg hover:bg-white/5 text-gray-400 hover:text-white transition-colors text-[10px]"
|
| 29 |
+
>
|
| 30 |
+
<PanelLeftOpen size={14} />
|
| 31 |
+
<span className="hidden sm:inline">Data</span>
|
| 32 |
+
</button>
|
| 33 |
+
|
| 34 |
+
{/* Logo */}
|
| 35 |
+
<div
|
| 36 |
+
className="w-7 h-7 rounded-lg flex items-center justify-center shadow-lg shrink-0"
|
| 37 |
+
style={{ background: '#1e3a5f', boxShadow: '0 4px 12px rgba(30,58,95,0.4)' }}
|
| 38 |
+
>
|
| 39 |
+
<Database size={13} className="text-white" />
|
| 40 |
+
</div>
|
| 41 |
+
|
| 42 |
+
{/* Title */}
|
| 43 |
+
<div>
|
| 44 |
+
<h1 className="text-sm font-bold text-white tracking-tight leading-none">
|
| 45 |
+
SQL Agent OpenEnv
|
| 46 |
+
</h1>
|
| 47 |
+
<p className="text-[10px] text-gray-600 hidden sm:block mt-0.5">
|
| 48 |
+
Reinforcement Learning Environment
|
| 49 |
+
</p>
|
| 50 |
+
</div>
|
| 51 |
+
</div>
|
| 52 |
+
|
| 53 |
+
<div className="flex items-center gap-2 sm:gap-3">
|
| 54 |
+
{/* Connection status */}
|
| 55 |
+
{dbSeeded ? (
|
| 56 |
+
<div className="hidden sm:flex items-center gap-1.5 text-[10px] text-green-400">
|
| 57 |
+
<span className="w-1.5 h-1.5 rounded-full bg-green-400 inline-block" />
|
| 58 |
+
benchmark db
|
| 59 |
+
</div>
|
| 60 |
+
) : (
|
| 61 |
+
<div className="hidden sm:flex items-center gap-1.5 text-[10px] text-amber-400">
|
| 62 |
+
<span className="w-1.5 h-1.5 rounded-full bg-amber-400 inline-block animate-pulse" />
|
| 63 |
+
connecting...
|
| 64 |
+
</div>
|
| 65 |
+
)}
|
| 66 |
+
|
| 67 |
+
{/* RL indicator */}
|
| 68 |
+
<div className="hidden md:flex items-center gap-1.5 text-[10px] text-violet-400 border border-violet-500/20 rounded-full px-2 py-0.5">
|
| 69 |
+
<Cpu size={10} />
|
| 70 |
+
LinUCB Active
|
| 71 |
+
</div>
|
| 72 |
+
|
| 73 |
+
{/* Difficulty selector */}
|
| 74 |
+
<div className="flex items-center gap-1 border border-white/[0.06] rounded-lg p-0.5">
|
| 75 |
+
{DIFFICULTIES.map((d) => (
|
| 76 |
+
<button
|
| 77 |
+
key={d.id}
|
| 78 |
+
onClick={() => setTaskDifficulty(d.id)}
|
| 79 |
+
className={`text-[10px] font-semibold px-2 py-1 rounded transition-all ${
|
| 80 |
+
taskDifficulty === d.id
|
| 81 |
+
? `${d.color} border`
|
| 82 |
+
: 'text-gray-500 hover:text-gray-300 border border-transparent'
|
| 83 |
+
}`}
|
| 84 |
+
>
|
| 85 |
+
{d.label}
|
| 86 |
+
</button>
|
| 87 |
+
))}
|
| 88 |
+
</div>
|
| 89 |
+
|
| 90 |
+
{/* Theme toggle */}
|
| 91 |
+
<button
|
| 92 |
+
onClick={toggleTheme}
|
| 93 |
+
className="p-1.5 rounded-lg hover:bg-white/5 transition-colors theme-text-muted"
|
| 94 |
+
title={theme === 'dark' ? 'Switch to light' : 'Switch to dark'}
|
| 95 |
+
>
|
| 96 |
+
{theme === 'dark' ? <Sun size={14} /> : <Moon size={14} />}
|
| 97 |
+
</button>
|
| 98 |
+
|
| 99 |
+
{/* Mobile right sidebar toggle */}
|
| 100 |
+
<button
|
| 101 |
+
onClick={onToggleRight}
|
| 102 |
+
className="lg:hidden flex items-center gap-1 px-2 py-1.5 rounded-lg hover:bg-white/5 text-gray-400 hover:text-white transition-colors text-[10px]"
|
| 103 |
+
>
|
| 104 |
+
<span className="hidden sm:inline">GEPA</span>
|
| 105 |
+
<PanelRightOpen size={14} />
|
| 106 |
+
</button>
|
| 107 |
+
</div>
|
| 108 |
+
</header>
|
| 109 |
+
)
|
| 110 |
+
}
|
frontend/src/components/LeftSidebar.tsx
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useState } from 'react'
|
| 2 |
+
import { motion, AnimatePresence } from 'framer-motion'
|
| 3 |
+
import { Database, Table2, ChevronDown, ChevronRight, GitFork, ShoppingCart } from 'lucide-react'
|
| 4 |
+
import { useStore } from '../store/useStore'
|
| 5 |
+
import type { Difficulty } from '../lib/types'
|
| 6 |
+
|
| 7 |
+
const DIFFICULTY_CONFIG: Record<Difficulty, { label: string; bg: string; text: string; border: string }> = {
|
| 8 |
+
easy: { label: 'Easy', bg: 'bg-green-500/10', text: 'text-green-400', border: 'border-green-500/30' },
|
| 9 |
+
medium: { label: 'Medium', bg: 'bg-amber-500/10', text: 'text-amber-400', border: 'border-amber-500/30' },
|
| 10 |
+
hard: { label: 'Hard', bg: 'bg-red-500/10', text: 'text-red-400', border: 'border-red-500/30' },
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
export function LeftSidebar() {
|
| 14 |
+
const { tables, taskDifficulty, setTaskDifficulty, dbSeeded } = useStore()
|
| 15 |
+
const [tablesExpanded, setTablesExpanded] = useState(true)
|
| 16 |
+
|
| 17 |
+
const cfg = DIFFICULTY_CONFIG[taskDifficulty]
|
| 18 |
+
|
| 19 |
+
return (
|
| 20 |
+
<div className="flex flex-col gap-4 py-1">
|
| 21 |
+
{/* Task Difficulty */}
|
| 22 |
+
<section>
|
| 23 |
+
<div className="text-[10px] font-semibold text-gray-500 uppercase tracking-widest mb-2 flex items-center gap-1.5">
|
| 24 |
+
<GitFork size={10} className="text-violet-400" />
|
| 25 |
+
Task Difficulty
|
| 26 |
+
</div>
|
| 27 |
+
<div className="flex flex-col gap-1">
|
| 28 |
+
{(Object.keys(DIFFICULTY_CONFIG) as Difficulty[]).map((d) => {
|
| 29 |
+
const c = DIFFICULTY_CONFIG[d]
|
| 30 |
+
const active = d === taskDifficulty
|
| 31 |
+
return (
|
| 32 |
+
<button
|
| 33 |
+
key={d}
|
| 34 |
+
onClick={() => setTaskDifficulty(d)}
|
| 35 |
+
className={`flex items-center justify-between px-3 py-2 rounded-lg border text-xs font-medium transition-all ${
|
| 36 |
+
active
|
| 37 |
+
? `${c.bg} ${c.text} ${c.border}`
|
| 38 |
+
: 'border-transparent text-gray-500 hover:text-gray-300 hover:bg-white/5'
|
| 39 |
+
}`}
|
| 40 |
+
>
|
| 41 |
+
<span>{c.label}</span>
|
| 42 |
+
{active && (
|
| 43 |
+
<span className={`text-[9px] font-mono ${c.text} opacity-70`}>selected</span>
|
| 44 |
+
)}
|
| 45 |
+
</button>
|
| 46 |
+
)
|
| 47 |
+
})}
|
| 48 |
+
</div>
|
| 49 |
+
</section>
|
| 50 |
+
|
| 51 |
+
{/* Schema Tables */}
|
| 52 |
+
<section>
|
| 53 |
+
<button
|
| 54 |
+
className="text-[10px] font-semibold text-gray-500 uppercase tracking-widest mb-2 flex items-center gap-1.5 w-full"
|
| 55 |
+
onClick={() => setTablesExpanded((v) => !v)}
|
| 56 |
+
>
|
| 57 |
+
<Database size={10} className="text-blue-400" />
|
| 58 |
+
<span className="flex-1 text-left">Database Schema</span>
|
| 59 |
+
{tablesExpanded ? <ChevronDown size={10} /> : <ChevronRight size={10} />}
|
| 60 |
+
</button>
|
| 61 |
+
<AnimatePresence>
|
| 62 |
+
{tablesExpanded && (
|
| 63 |
+
<motion.div
|
| 64 |
+
initial={{ opacity: 0, height: 0 }}
|
| 65 |
+
animate={{ opacity: 1, height: 'auto' }}
|
| 66 |
+
exit={{ opacity: 0, height: 0 }}
|
| 67 |
+
className="overflow-hidden"
|
| 68 |
+
>
|
| 69 |
+
{dbSeeded && tables.length > 0 ? (
|
| 70 |
+
<div className="flex flex-col gap-1">
|
| 71 |
+
{tables.map((t) => (
|
| 72 |
+
<div
|
| 73 |
+
key={t.name}
|
| 74 |
+
className="flex items-center justify-between px-2.5 py-1.5 rounded-lg border border-white/[0.04] bg-white/[0.02] hover:bg-white/[0.04] transition-colors"
|
| 75 |
+
>
|
| 76 |
+
<div className="flex items-center gap-1.5">
|
| 77 |
+
<Table2 size={10} className="text-blue-400 shrink-0" />
|
| 78 |
+
<span className="text-xs text-gray-300 font-mono">{t.name}</span>
|
| 79 |
+
</div>
|
| 80 |
+
<span className="text-[9px] text-gray-600 font-mono tabular-nums">
|
| 81 |
+
{t.rows.toLocaleString()}
|
| 82 |
+
</span>
|
| 83 |
+
</div>
|
| 84 |
+
))}
|
| 85 |
+
</div>
|
| 86 |
+
) : (
|
| 87 |
+
<div className="flex flex-col gap-1">
|
| 88 |
+
{[120, 80, 95, 60, 70].map((w, i) => (
|
| 89 |
+
<div
|
| 90 |
+
key={i}
|
| 91 |
+
className="flex items-center justify-between px-2.5 py-1.5 rounded-lg border border-white/[0.04] bg-white/[0.02]"
|
| 92 |
+
>
|
| 93 |
+
<div
|
| 94 |
+
className="h-2 rounded bg-white/10 animate-pulse"
|
| 95 |
+
style={{ width: w }}
|
| 96 |
+
/>
|
| 97 |
+
<div className="h-2 w-8 rounded bg-white/10 animate-pulse" />
|
| 98 |
+
</div>
|
| 99 |
+
))}
|
| 100 |
+
</div>
|
| 101 |
+
)}
|
| 102 |
+
</motion.div>
|
| 103 |
+
)}
|
| 104 |
+
</AnimatePresence>
|
| 105 |
+
</section>
|
| 106 |
+
|
| 107 |
+
{/* Business Context */}
|
| 108 |
+
<section>
|
| 109 |
+
<div className="text-[10px] font-semibold text-gray-500 uppercase tracking-widest mb-2 flex items-center gap-1.5">
|
| 110 |
+
<ShoppingCart size={10} className="text-orange-400" />
|
| 111 |
+
Business Context
|
| 112 |
+
</div>
|
| 113 |
+
<div
|
| 114 |
+
className="rounded-xl border border-white/[0.05] p-3 text-[11px] text-gray-500 leading-relaxed"
|
| 115 |
+
style={{ background: 'var(--bg-card)' }}
|
| 116 |
+
>
|
| 117 |
+
<p className="mb-2 text-gray-400 font-medium">E-Commerce Marketplace</p>
|
| 118 |
+
<p>
|
| 119 |
+
Multi-vendor marketplace with products, orders, sellers, users, and reviews.
|
| 120 |
+
Supports complex analytical queries across sales, inventory, and user behavior.
|
| 121 |
+
</p>
|
| 122 |
+
<div className="mt-2 flex flex-wrap gap-1">
|
| 123 |
+
{['Products', 'Orders', 'Sellers', 'Users', 'Reviews', 'Categories'].map((t) => (
|
| 124 |
+
<span
|
| 125 |
+
key={t}
|
| 126 |
+
className="text-[9px] px-1.5 py-0.5 rounded border border-white/[0.06] text-gray-600"
|
| 127 |
+
>
|
| 128 |
+
{t}
|
| 129 |
+
</span>
|
| 130 |
+
))}
|
| 131 |
+
</div>
|
| 132 |
+
</div>
|
| 133 |
+
</section>
|
| 134 |
+
|
| 135 |
+
{/* Current task badge */}
|
| 136 |
+
<section>
|
| 137 |
+
<div
|
| 138 |
+
className={`rounded-xl border ${cfg.border} ${cfg.bg} p-3 flex flex-col gap-1.5`}
|
| 139 |
+
>
|
| 140 |
+
<div className="flex items-center justify-between">
|
| 141 |
+
<span className={`text-[10px] font-semibold uppercase tracking-wider ${cfg.text}`}>
|
| 142 |
+
Current Task
|
| 143 |
+
</span>
|
| 144 |
+
<span className={`text-[10px] font-mono ${cfg.text}`}>{cfg.label}</span>
|
| 145 |
+
</div>
|
| 146 |
+
<p className="text-[11px] text-gray-400 leading-relaxed">
|
| 147 |
+
{taskDifficulty === 'easy'
|
| 148 |
+
? 'Simple SELECT queries, basic filtering and aggregation'
|
| 149 |
+
: taskDifficulty === 'medium'
|
| 150 |
+
? 'Multi-table JOINs, GROUP BY, subqueries, window functions'
|
| 151 |
+
: 'Complex CTEs, rolling aggregations, cohort analysis, ranking'}
|
| 152 |
+
</p>
|
| 153 |
+
</div>
|
| 154 |
+
</section>
|
| 155 |
+
</div>
|
| 156 |
+
)
|
| 157 |
+
}
|
frontend/src/components/PerformanceGraph.tsx
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect } from 'react'
|
| 2 |
+
import {
|
| 3 |
+
LineChart, Line, BarChart, Bar,
|
| 4 |
+
XAxis, YAxis, CartesianGrid, Tooltip,
|
| 5 |
+
ResponsiveContainer, ReferenceLine,
|
| 6 |
+
} from 'recharts'
|
| 7 |
+
import { TrendingUp, Loader2, RefreshCw } from 'lucide-react'
|
| 8 |
+
import { useStore } from '../store/useStore'
|
| 9 |
+
import { fetchRLState } from '../lib/api'
|
| 10 |
+
|
| 11 |
+
const CustomTooltip = ({
|
| 12 |
+
active,
|
| 13 |
+
payload,
|
| 14 |
+
label,
|
| 15 |
+
}: {
|
| 16 |
+
active?: boolean
|
| 17 |
+
payload?: { value: number; name: string; color: string }[]
|
| 18 |
+
label?: string | number
|
| 19 |
+
}) => {
|
| 20 |
+
if (active && payload?.length) {
|
| 21 |
+
return (
|
| 22 |
+
<div
|
| 23 |
+
className="border border-white/10 rounded-lg px-3 py-2 text-xs"
|
| 24 |
+
style={{ background: '#1a1a2e' }}
|
| 25 |
+
>
|
| 26 |
+
<p className="text-gray-400 mb-1">#{label}</p>
|
| 27 |
+
{payload.map((p) => (
|
| 28 |
+
<p key={p.name} style={{ color: p.color }}>
|
| 29 |
+
{p.name}: <span className="font-semibold">{p.value}</span>
|
| 30 |
+
</p>
|
| 31 |
+
))}
|
| 32 |
+
</div>
|
| 33 |
+
)
|
| 34 |
+
}
|
| 35 |
+
return null
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
export function PerformanceGraph() {
|
| 39 |
+
const { rlState, setRlState } = useStore()
|
| 40 |
+
|
| 41 |
+
const load = async () => {
|
| 42 |
+
try {
|
| 43 |
+
const data = await fetchRLState()
|
| 44 |
+
setRlState(data)
|
| 45 |
+
} catch {
|
| 46 |
+
// noop β backend might not be up
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
useEffect(() => {
|
| 51 |
+
void load()
|
| 52 |
+
const interval = setInterval(() => void load(), 10_000)
|
| 53 |
+
return () => clearInterval(interval)
|
| 54 |
+
// eslint-disable-next-line react-hooks/exhaustive-deps
|
| 55 |
+
}, [])
|
| 56 |
+
|
| 57 |
+
if (!rlState) {
|
| 58 |
+
return (
|
| 59 |
+
<div className="flex flex-col items-center justify-center h-40 text-gray-600 gap-2">
|
| 60 |
+
<TrendingUp size={24} className="text-gray-700" />
|
| 61 |
+
<p className="text-[11px] text-center">
|
| 62 |
+
RL metrics appear after agent episodes
|
| 63 |
+
</p>
|
| 64 |
+
<Loader2 size={14} className="animate-spin text-gray-700" />
|
| 65 |
+
</div>
|
| 66 |
+
)
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
const { totalEpisodes, successRate, currentAlpha, episodes, actionDistribution } = rlState
|
| 70 |
+
|
| 71 |
+
return (
|
| 72 |
+
<div className="flex flex-col gap-3">
|
| 73 |
+
{/* Stats row */}
|
| 74 |
+
<div className="grid grid-cols-3 gap-1.5">
|
| 75 |
+
{[
|
| 76 |
+
{ label: 'Episodes', value: totalEpisodes, color: 'text-blue-400' },
|
| 77 |
+
{ label: 'Success', value: `${(successRate * 100).toFixed(0)}%`, color: 'text-green-400' },
|
| 78 |
+
{ label: 'Alpha', value: currentAlpha.toFixed(3), color: 'text-orange-400' },
|
| 79 |
+
].map((s) => (
|
| 80 |
+
<div
|
| 81 |
+
key={s.label}
|
| 82 |
+
className="bg-white/5 rounded-xl p-2 text-center"
|
| 83 |
+
>
|
| 84 |
+
<div className={`text-sm font-bold font-mono ${s.color}`}>{s.value}</div>
|
| 85 |
+
<div className="text-[9px] text-gray-500 mt-0.5">{s.label}</div>
|
| 86 |
+
</div>
|
| 87 |
+
))}
|
| 88 |
+
</div>
|
| 89 |
+
|
| 90 |
+
{/* Reward per episode */}
|
| 91 |
+
{episodes.length > 0 && (
|
| 92 |
+
<div>
|
| 93 |
+
<div className="flex items-center justify-between mb-1.5">
|
| 94 |
+
<div className="text-[10px] text-gray-500 font-medium">Reward per Episode</div>
|
| 95 |
+
<button
|
| 96 |
+
onClick={() => void load()}
|
| 97 |
+
className="p-1 rounded hover:bg-white/5 text-gray-600 hover:text-gray-400 transition-colors"
|
| 98 |
+
title="Refresh"
|
| 99 |
+
>
|
| 100 |
+
<RefreshCw size={10} />
|
| 101 |
+
</button>
|
| 102 |
+
</div>
|
| 103 |
+
<ResponsiveContainer width="100%" height={110}>
|
| 104 |
+
<LineChart data={episodes} margin={{ top: 4, right: 4, bottom: 0, left: -20 }}>
|
| 105 |
+
<CartesianGrid strokeDasharray="3 3" stroke="#ffffff08" />
|
| 106 |
+
<XAxis dataKey="episode" tick={{ fontSize: 9, fill: '#6b7280' }} />
|
| 107 |
+
<YAxis domain={[-1, 1]} tick={{ fontSize: 9, fill: '#6b7280' }} />
|
| 108 |
+
<Tooltip content={<CustomTooltip />} />
|
| 109 |
+
<ReferenceLine y={0} stroke="#ffffff20" strokeDasharray="3 3" />
|
| 110 |
+
<Line
|
| 111 |
+
type="monotone"
|
| 112 |
+
dataKey="totalReward"
|
| 113 |
+
name="Reward"
|
| 114 |
+
stroke="#f97316"
|
| 115 |
+
strokeWidth={2}
|
| 116 |
+
dot={episodes.length < 30 ? { fill: '#f97316', r: 2 } : false}
|
| 117 |
+
activeDot={{ r: 4 }}
|
| 118 |
+
/>
|
| 119 |
+
</LineChart>
|
| 120 |
+
</ResponsiveContainer>
|
| 121 |
+
</div>
|
| 122 |
+
)}
|
| 123 |
+
|
| 124 |
+
{/* Action distribution */}
|
| 125 |
+
{actionDistribution.length > 0 && (
|
| 126 |
+
<div>
|
| 127 |
+
<div className="text-[10px] text-gray-500 mb-1.5 font-medium">
|
| 128 |
+
LinUCB Action Distribution
|
| 129 |
+
</div>
|
| 130 |
+
<ResponsiveContainer width="100%" height={90}>
|
| 131 |
+
<BarChart
|
| 132 |
+
data={actionDistribution}
|
| 133 |
+
margin={{ top: 4, right: 4, bottom: 0, left: -20 }}
|
| 134 |
+
>
|
| 135 |
+
<CartesianGrid strokeDasharray="3 3" stroke="#ffffff08" />
|
| 136 |
+
<XAxis
|
| 137 |
+
dataKey="action"
|
| 138 |
+
tick={{ fontSize: 8, fill: '#6b7280' }}
|
| 139 |
+
tickFormatter={(v: string) => v.replace('FIX_', '').slice(0, 6)}
|
| 140 |
+
/>
|
| 141 |
+
<YAxis tick={{ fontSize: 9, fill: '#6b7280' }} />
|
| 142 |
+
<Tooltip content={<CustomTooltip />} />
|
| 143 |
+
<Bar dataKey="count" name="Uses" fill="#8b5cf6" radius={[3, 3, 0, 0]} />
|
| 144 |
+
</BarChart>
|
| 145 |
+
</ResponsiveContainer>
|
| 146 |
+
</div>
|
| 147 |
+
)}
|
| 148 |
+
|
| 149 |
+
{/* Success rate line */}
|
| 150 |
+
{episodes.length >= 3 && (
|
| 151 |
+
<div>
|
| 152 |
+
<div className="text-[10px] text-gray-500 mb-1.5 font-medium">
|
| 153 |
+
Rolling Success Rate
|
| 154 |
+
</div>
|
| 155 |
+
<ResponsiveContainer width="100%" height={80}>
|
| 156 |
+
<LineChart data={episodes} margin={{ top: 4, right: 4, bottom: 0, left: -20 }}>
|
| 157 |
+
<CartesianGrid strokeDasharray="3 3" stroke="#ffffff08" />
|
| 158 |
+
<XAxis dataKey="episode" tick={{ fontSize: 9, fill: '#6b7280' }} />
|
| 159 |
+
<YAxis domain={[0, 1]} tick={{ fontSize: 9, fill: '#6b7280' }} />
|
| 160 |
+
<Tooltip content={<CustomTooltip />} />
|
| 161 |
+
<Line
|
| 162 |
+
type="monotone"
|
| 163 |
+
dataKey="successRate"
|
| 164 |
+
name="Success"
|
| 165 |
+
stroke="#22c55e"
|
| 166 |
+
strokeWidth={2}
|
| 167 |
+
dot={false}
|
| 168 |
+
/>
|
| 169 |
+
</LineChart>
|
| 170 |
+
</ResponsiveContainer>
|
| 171 |
+
</div>
|
| 172 |
+
)}
|
| 173 |
+
</div>
|
| 174 |
+
)
|
| 175 |
+
}
|
frontend/src/components/PromptEvolution.tsx
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useState, useEffect } from 'react'
|
| 2 |
+
import { motion, AnimatePresence } from 'framer-motion'
|
| 3 |
+
import { Brain, ChevronDown, ChevronUp, Zap, History } from 'lucide-react'
|
| 4 |
+
import { useStore } from '../store/useStore'
|
| 5 |
+
import { fetchPromptHistory } from '../lib/api'
|
| 6 |
+
|
| 7 |
+
const SEED_PROMPT = `You are a SQL expert. Given a natural language question and a SQLite database schema, write a correct SQL query.
|
| 8 |
+
|
| 9 |
+
Rules:
|
| 10 |
+
- Output ONLY the SQL query, nothing else
|
| 11 |
+
- No markdown, no code fences, no explanation
|
| 12 |
+
- Use SQLite syntax
|
| 13 |
+
- Always qualify column names with table aliases when using JOINs`
|
| 14 |
+
|
| 15 |
+
export function PromptEvolution() {
|
| 16 |
+
const { currentPrompt, promptGeneration, promptHistory, setPromptData } = useStore()
|
| 17 |
+
const [expanded, setExpanded] = useState(false)
|
| 18 |
+
const [historyExpanded, setHistoryExpanded] = useState(false)
|
| 19 |
+
const [loading, setLoading] = useState(false)
|
| 20 |
+
|
| 21 |
+
const prompt = currentPrompt || SEED_PROMPT
|
| 22 |
+
const generation = promptGeneration
|
| 23 |
+
|
| 24 |
+
const loadHistory = async () => {
|
| 25 |
+
setLoading(true)
|
| 26 |
+
try {
|
| 27 |
+
const data = await fetchPromptHistory()
|
| 28 |
+
setPromptData(data)
|
| 29 |
+
} catch {
|
| 30 |
+
// noop
|
| 31 |
+
} finally {
|
| 32 |
+
setLoading(false)
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
useEffect(() => {
|
| 37 |
+
void loadHistory()
|
| 38 |
+
// eslint-disable-next-line react-hooks/exhaustive-deps
|
| 39 |
+
}, [])
|
| 40 |
+
|
| 41 |
+
return (
|
| 42 |
+
<div className="flex flex-col gap-2">
|
| 43 |
+
{/* Header */}
|
| 44 |
+
<button
|
| 45 |
+
onClick={() => setExpanded((v) => !v)}
|
| 46 |
+
className="flex items-center justify-between w-full group"
|
| 47 |
+
>
|
| 48 |
+
<div className="flex items-center gap-2">
|
| 49 |
+
<Brain size={14} className="text-violet-400" />
|
| 50 |
+
<span className="text-xs font-semibold text-white/70">System Prompt</span>
|
| 51 |
+
{generation > 0 ? (
|
| 52 |
+
<span className="text-[10px] bg-violet-500/20 text-violet-300 border border-violet-500/30 rounded-full px-2 py-0.5">
|
| 53 |
+
Gen {generation} Β· Optimized
|
| 54 |
+
</span>
|
| 55 |
+
) : (
|
| 56 |
+
<span className="text-[10px] bg-white/5 text-gray-500 rounded-full px-2 py-0.5">
|
| 57 |
+
Seed
|
| 58 |
+
</span>
|
| 59 |
+
)}
|
| 60 |
+
</div>
|
| 61 |
+
{expanded ? (
|
| 62 |
+
<ChevronUp size={13} className="text-gray-500" />
|
| 63 |
+
) : (
|
| 64 |
+
<ChevronDown size={13} className="text-gray-500" />
|
| 65 |
+
)}
|
| 66 |
+
</button>
|
| 67 |
+
|
| 68 |
+
<AnimatePresence>
|
| 69 |
+
{expanded && (
|
| 70 |
+
<motion.div
|
| 71 |
+
initial={{ opacity: 0, height: 0 }}
|
| 72 |
+
animate={{ opacity: 1, height: 'auto' }}
|
| 73 |
+
exit={{ opacity: 0, height: 0 }}
|
| 74 |
+
transition={{ duration: 0.2 }}
|
| 75 |
+
className="overflow-hidden"
|
| 76 |
+
>
|
| 77 |
+
{/* Prompt preview */}
|
| 78 |
+
<div className="max-h-40 overflow-y-auto">
|
| 79 |
+
<pre className="text-[11px] font-mono text-violet-200/70 bg-violet-950/30 rounded-xl p-3 border border-violet-500/20 whitespace-pre-wrap leading-relaxed">
|
| 80 |
+
{prompt}
|
| 81 |
+
</pre>
|
| 82 |
+
</div>
|
| 83 |
+
|
| 84 |
+
{/* History button */}
|
| 85 |
+
{promptHistory.length > 0 && (
|
| 86 |
+
<button
|
| 87 |
+
onClick={() => setHistoryExpanded((v) => !v)}
|
| 88 |
+
className="mt-2 w-full flex items-center justify-center gap-2 px-3 py-2 text-xs font-medium bg-violet-600/15 text-violet-300 border border-violet-500/25 rounded-xl hover:bg-violet-600/25 hover:border-violet-500/40 transition-all"
|
| 89 |
+
>
|
| 90 |
+
<History size={12} />
|
| 91 |
+
{historyExpanded ? 'Hide' : 'View'} Evolution History
|
| 92 |
+
<span className="text-[10px] text-violet-400/60 ml-1">
|
| 93 |
+
({promptHistory.length} gen{promptHistory.length !== 1 ? 's' : ''})
|
| 94 |
+
</span>
|
| 95 |
+
</button>
|
| 96 |
+
)}
|
| 97 |
+
|
| 98 |
+
{/* Generation history */}
|
| 99 |
+
<AnimatePresence>
|
| 100 |
+
{historyExpanded && promptHistory.length > 0 && (
|
| 101 |
+
<motion.div
|
| 102 |
+
initial={{ height: 0, opacity: 0 }}
|
| 103 |
+
animate={{ height: 'auto', opacity: 1 }}
|
| 104 |
+
exit={{ height: 0, opacity: 0 }}
|
| 105 |
+
transition={{ duration: 0.15 }}
|
| 106 |
+
className="overflow-hidden mt-2"
|
| 107 |
+
>
|
| 108 |
+
<div className="flex flex-col gap-1.5">
|
| 109 |
+
<div className="text-[10px] text-gray-500 font-medium flex items-center gap-1">
|
| 110 |
+
<Zap size={10} className="text-violet-400" />
|
| 111 |
+
Optimization History
|
| 112 |
+
</div>
|
| 113 |
+
{promptHistory.map((snap) => (
|
| 114 |
+
<div
|
| 115 |
+
key={snap.generation}
|
| 116 |
+
className="border border-white/5 rounded-xl p-2.5 hover:border-white/10 hover:bg-white/[0.02] transition-all"
|
| 117 |
+
>
|
| 118 |
+
<div className="flex items-center justify-between mb-1">
|
| 119 |
+
<span className="text-[10px] font-semibold text-violet-400">
|
| 120 |
+
Generation {snap.generation}
|
| 121 |
+
</span>
|
| 122 |
+
<span className="text-[10px] font-mono text-green-400">
|
| 123 |
+
{(snap.score * 100).toFixed(0)}%
|
| 124 |
+
</span>
|
| 125 |
+
</div>
|
| 126 |
+
<p className="text-[10px] text-gray-400 leading-relaxed line-clamp-2">
|
| 127 |
+
{snap.summary}
|
| 128 |
+
</p>
|
| 129 |
+
<p className="text-[9px] text-gray-600 mt-1">{snap.timestamp}</p>
|
| 130 |
+
</div>
|
| 131 |
+
))}
|
| 132 |
+
</div>
|
| 133 |
+
</motion.div>
|
| 134 |
+
)}
|
| 135 |
+
</AnimatePresence>
|
| 136 |
+
|
| 137 |
+
{loading && (
|
| 138 |
+
<div className="flex items-center gap-2 text-[10px] text-gray-500 mt-2 px-1">
|
| 139 |
+
<span className="w-3 h-3 border border-violet-500/40 border-t-violet-400 rounded-full animate-spin inline-block" />
|
| 140 |
+
Loading history...
|
| 141 |
+
</div>
|
| 142 |
+
)}
|
| 143 |
+
</motion.div>
|
| 144 |
+
)}
|
| 145 |
+
</AnimatePresence>
|
| 146 |
+
</div>
|
| 147 |
+
)
|
| 148 |
+
}
|
frontend/src/components/ResultsTable.tsx
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
const MAX_ROWS = 10
|
| 2 |
+
const MAX_CELL_LEN = 30
|
| 3 |
+
|
| 4 |
+
function truncate(val: unknown): string {
|
| 5 |
+
const s = val === null || val === undefined ? 'null' : String(val)
|
| 6 |
+
return s.length > MAX_CELL_LEN ? s.slice(0, MAX_CELL_LEN) + 'β¦' : s
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
interface ResultsTableProps {
|
| 10 |
+
rows: Record<string, unknown>[]
|
| 11 |
+
rowCount: number
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
export function ResultsTable({ rows, rowCount }: ResultsTableProps) {
|
| 15 |
+
if (rows.length === 0) {
|
| 16 |
+
return (
|
| 17 |
+
<div className="text-xs text-gray-500 italic px-3 py-2 border border-white/[0.06] rounded-xl">
|
| 18 |
+
No rows returned.
|
| 19 |
+
</div>
|
| 20 |
+
)
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
const columns = Object.keys(rows[0])
|
| 24 |
+
const displayRows = rows.slice(0, MAX_ROWS)
|
| 25 |
+
|
| 26 |
+
return (
|
| 27 |
+
<div className="overflow-auto max-h-60 rounded-xl border border-white/[0.06]" style={{ fontSize: 11 }}>
|
| 28 |
+
{rowCount > MAX_ROWS && (
|
| 29 |
+
<div className="px-3 py-1 text-[10px] text-amber-400/70 bg-amber-500/5 border-b border-amber-500/10 shrink-0">
|
| 30 |
+
Showing {MAX_ROWS} of {rowCount} rows
|
| 31 |
+
</div>
|
| 32 |
+
)}
|
| 33 |
+
<table className="w-full font-mono border-collapse">
|
| 34 |
+
<thead>
|
| 35 |
+
<tr
|
| 36 |
+
className="border-b border-white/[0.06] sticky top-0"
|
| 37 |
+
style={{ background: 'var(--bg-tertiary)' }}
|
| 38 |
+
>
|
| 39 |
+
{columns.map((col) => (
|
| 40 |
+
<th
|
| 41 |
+
key={col}
|
| 42 |
+
className="px-3 py-1.5 text-left text-[10px] font-semibold text-gray-500 uppercase tracking-wider whitespace-nowrap"
|
| 43 |
+
>
|
| 44 |
+
{col}
|
| 45 |
+
</th>
|
| 46 |
+
))}
|
| 47 |
+
</tr>
|
| 48 |
+
</thead>
|
| 49 |
+
<tbody>
|
| 50 |
+
{displayRows.map((row, i) => (
|
| 51 |
+
<tr
|
| 52 |
+
key={i}
|
| 53 |
+
className="border-b border-white/[0.03] hover:bg-white/[0.02] transition-colors"
|
| 54 |
+
>
|
| 55 |
+
{columns.map((col) => (
|
| 56 |
+
<td
|
| 57 |
+
key={col}
|
| 58 |
+
className={`px-3 py-1.5 whitespace-nowrap ${
|
| 59 |
+
row[col] === null ? 'text-gray-600 italic' : 'text-gray-300'
|
| 60 |
+
}`}
|
| 61 |
+
title={row[col] !== null ? String(row[col]) : undefined}
|
| 62 |
+
>
|
| 63 |
+
{truncate(row[col])}
|
| 64 |
+
</td>
|
| 65 |
+
))}
|
| 66 |
+
</tr>
|
| 67 |
+
))}
|
| 68 |
+
</tbody>
|
| 69 |
+
</table>
|
| 70 |
+
<div
|
| 71 |
+
className="px-3 py-1 text-[10px] text-gray-600 border-t border-white/[0.04]"
|
| 72 |
+
style={{ background: 'var(--bg-tertiary)' }}
|
| 73 |
+
>
|
| 74 |
+
Showing {displayRows.length} of {rowCount} rows
|
| 75 |
+
</div>
|
| 76 |
+
</div>
|
| 77 |
+
)
|
| 78 |
+
}
|
frontend/src/components/RightSidebar.tsx
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { Zap, Brain } from 'lucide-react'
|
| 2 |
+
import { PromptEvolution } from './PromptEvolution'
|
| 3 |
+
import { PerformanceGraph } from './PerformanceGraph'
|
| 4 |
+
|
| 5 |
+
export function RightSidebar() {
|
| 6 |
+
return (
|
| 7 |
+
<div className="flex flex-col h-full overflow-y-auto">
|
| 8 |
+
{/* GEPA Section */}
|
| 9 |
+
<div className="p-4 border-b border-white/[0.06] shrink-0">
|
| 10 |
+
<div className="flex items-center gap-2 text-[10px] font-semibold text-gray-500 uppercase tracking-widest mb-3">
|
| 11 |
+
<Brain size={10} className="text-violet-400" />
|
| 12 |
+
GEPA Prompt Evolution
|
| 13 |
+
</div>
|
| 14 |
+
<PromptEvolution />
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
{/* RL Charts */}
|
| 18 |
+
<div className="p-4 flex-1 overflow-y-auto">
|
| 19 |
+
<div className="flex items-center gap-2 text-[10px] font-semibold text-gray-500 uppercase tracking-widest mb-3">
|
| 20 |
+
<Zap size={10} className="text-violet-400" />
|
| 21 |
+
RL Learning Progress
|
| 22 |
+
</div>
|
| 23 |
+
<PerformanceGraph />
|
| 24 |
+
</div>
|
| 25 |
+
</div>
|
| 26 |
+
)
|
| 27 |
+
}
|
frontend/src/index.css
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@tailwind base;
|
| 2 |
+
@tailwind components;
|
| 3 |
+
@tailwind utilities;
|
| 4 |
+
|
| 5 |
+
/* βββ Theme Variables ββββββββββββββββββββββββββββββββββββββββββββ */
|
| 6 |
+
|
| 7 |
+
:root {
|
| 8 |
+
--bg-primary: #08080d;
|
| 9 |
+
--bg-secondary: #09090f;
|
| 10 |
+
--bg-tertiary: #0a0a12;
|
| 11 |
+
--bg-card: #0e0e16;
|
| 12 |
+
--bg-input: rgba(255, 255, 255, 0.03);
|
| 13 |
+
--bg-hover: rgba(255, 255, 255, 0.02);
|
| 14 |
+
--bg-hover-strong: rgba(255, 255, 255, 0.05);
|
| 15 |
+
|
| 16 |
+
--text-primary: #ffffff;
|
| 17 |
+
--text-secondary: rgba(255, 255, 255, 0.7);
|
| 18 |
+
--text-muted: #6b7280;
|
| 19 |
+
--text-dim: #4b5563;
|
| 20 |
+
|
| 21 |
+
--border-color: rgba(255, 255, 255, 0.06);
|
| 22 |
+
--border-hover: rgba(255, 255, 255, 0.12);
|
| 23 |
+
|
| 24 |
+
--accent-violet: #8b5cf6;
|
| 25 |
+
--accent-green: #22c55e;
|
| 26 |
+
--accent-orange: #f97316;
|
| 27 |
+
--accent-red: #ef4444;
|
| 28 |
+
--accent-blue: #3b82f6;
|
| 29 |
+
|
| 30 |
+
--background: var(--bg-primary);
|
| 31 |
+
--foreground: var(--text-primary);
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
[data-theme="light"] {
|
| 35 |
+
--bg-primary: #f5f6f8;
|
| 36 |
+
--bg-secondary: #ffffff;
|
| 37 |
+
--bg-tertiary: #eef0f3;
|
| 38 |
+
--bg-card: #ffffff;
|
| 39 |
+
--bg-input: rgba(0, 0, 0, 0.04);
|
| 40 |
+
--bg-hover: rgba(0, 0, 0, 0.02);
|
| 41 |
+
--bg-hover-strong: rgba(0, 0, 0, 0.05);
|
| 42 |
+
|
| 43 |
+
--text-primary: #111827;
|
| 44 |
+
--text-secondary: #374151;
|
| 45 |
+
--text-muted: #6b7280;
|
| 46 |
+
--text-dim: #9ca3af;
|
| 47 |
+
|
| 48 |
+
--border-color: rgba(0, 0, 0, 0.1);
|
| 49 |
+
--border-hover: rgba(0, 0, 0, 0.2);
|
| 50 |
+
|
| 51 |
+
--background: var(--bg-primary);
|
| 52 |
+
--foreground: var(--text-primary);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
/* βββ Base βββββββββββββββββββββββββββββββββββββββββββββββββββββββ */
|
| 56 |
+
|
| 57 |
+
* {
|
| 58 |
+
box-sizing: border-box;
|
| 59 |
+
margin: 0;
|
| 60 |
+
padding: 0;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
html,
|
| 64 |
+
body,
|
| 65 |
+
#root {
|
| 66 |
+
height: 100%;
|
| 67 |
+
width: 100%;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
body {
|
| 71 |
+
background: var(--bg-primary);
|
| 72 |
+
color: var(--text-primary);
|
| 73 |
+
font-family: ui-monospace, 'SF Mono', Consolas, 'Liberation Mono', monospace;
|
| 74 |
+
-webkit-font-smoothing: antialiased;
|
| 75 |
+
-moz-osx-font-smoothing: grayscale;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
/* βββ Theme Utility Classes ββββββββββββββββββββββββββββββββββββββββ */
|
| 79 |
+
|
| 80 |
+
.theme-bg-primary { background-color: var(--bg-primary) !important; }
|
| 81 |
+
.theme-bg-secondary { background-color: var(--bg-secondary) !important; }
|
| 82 |
+
.theme-bg-tertiary { background-color: var(--bg-tertiary) !important; }
|
| 83 |
+
.theme-bg-card { background-color: var(--bg-card) !important; }
|
| 84 |
+
.theme-text-primary { color: var(--text-primary) !important; }
|
| 85 |
+
.theme-text-secondary { color: var(--text-secondary) !important; }
|
| 86 |
+
.theme-text-muted { color: var(--text-muted) !important; }
|
| 87 |
+
.theme-border { border-color: var(--border-color) !important; }
|
| 88 |
+
.theme-border-hover { border-color: var(--border-hover) !important; }
|
| 89 |
+
|
| 90 |
+
/* βββ Scrollbars βββββββββββββββββββββββββββββββββββββββββββββββββββ */
|
| 91 |
+
|
| 92 |
+
.scrollbar-none {
|
| 93 |
+
-ms-overflow-style: none;
|
| 94 |
+
scrollbar-width: none;
|
| 95 |
+
}
|
| 96 |
+
.scrollbar-none::-webkit-scrollbar {
|
| 97 |
+
display: none;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
::-webkit-scrollbar {
|
| 101 |
+
width: 4px;
|
| 102 |
+
height: 4px;
|
| 103 |
+
}
|
| 104 |
+
::-webkit-scrollbar-track {
|
| 105 |
+
background: transparent;
|
| 106 |
+
}
|
| 107 |
+
::-webkit-scrollbar-thumb {
|
| 108 |
+
background: rgba(255, 255, 255, 0.1);
|
| 109 |
+
border-radius: 4px;
|
| 110 |
+
}
|
| 111 |
+
[data-theme="light"] ::-webkit-scrollbar-thumb {
|
| 112 |
+
background: rgba(0, 0, 0, 0.15);
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
/* βββ SQL Syntax Highlighting βββββββββββββββββββββββββββββββββββββ */
|
| 116 |
+
|
| 117 |
+
.sql-keyword { color: #a78bfa; font-weight: 600; }
|
| 118 |
+
.sql-function { color: #60a5fa; }
|
| 119 |
+
.sql-string { color: #34d399; }
|
| 120 |
+
.sql-number { color: #f97316; }
|
| 121 |
+
.sql-comment { color: #6b7280; font-style: italic; }
|
| 122 |
+
.sql-operator { color: #e5e7eb; }
|
| 123 |
+
|
| 124 |
+
/* βββ Blinking cursor ββββββββββββββββββββββββββββββββββββββββββββ */
|
| 125 |
+
|
| 126 |
+
@keyframes blink {
|
| 127 |
+
0%, 100% { opacity: 1; }
|
| 128 |
+
50% { opacity: 0; }
|
| 129 |
+
}
|
| 130 |
+
.cursor-blink {
|
| 131 |
+
display: inline-block;
|
| 132 |
+
width: 2px;
|
| 133 |
+
height: 1em;
|
| 134 |
+
background: currentColor;
|
| 135 |
+
animation: blink 1s step-end infinite;
|
| 136 |
+
vertical-align: text-bottom;
|
| 137 |
+
margin-left: 1px;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
/* βββ Reward pulse animation βββββββββββββββββββββββββββββββββββββββ */
|
| 141 |
+
|
| 142 |
+
@keyframes rewardPulse {
|
| 143 |
+
0% { transform: scale(1); opacity: 0.7; }
|
| 144 |
+
50% { transform: scale(1.15); opacity: 1; }
|
| 145 |
+
100% { transform: scale(1); opacity: 1; }
|
| 146 |
+
}
|
| 147 |
+
.reward-pulse {
|
| 148 |
+
animation: rewardPulse 0.5s ease-out;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
/* βββ Optimizing banner ββββββββββββββββββββββββββββββββββββββββββββ */
|
| 152 |
+
|
| 153 |
+
@keyframes shimmer {
|
| 154 |
+
0% { background-position: -200% 0; }
|
| 155 |
+
100% { background-position: 200% 0; }
|
| 156 |
+
}
|
| 157 |
+
.shimmer-banner {
|
| 158 |
+
background: linear-gradient(
|
| 159 |
+
90deg,
|
| 160 |
+
rgba(139, 92, 246, 0.15) 0%,
|
| 161 |
+
rgba(139, 92, 246, 0.3) 50%,
|
| 162 |
+
rgba(139, 92, 246, 0.15) 100%
|
| 163 |
+
);
|
| 164 |
+
background-size: 200% 100%;
|
| 165 |
+
animation: shimmer 2s linear infinite;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
/* βββ Light Mode Global Overrides βββββββββββββββββββββββββββββββ */
|
| 169 |
+
|
| 170 |
+
[data-theme="light"] .text-white { color: var(--text-primary) !important; }
|
| 171 |
+
[data-theme="light"] .text-white\/70 { color: var(--text-secondary) !important; }
|
| 172 |
+
[data-theme="light"] .text-gray-200 { color: #1f2937 !important; }
|
| 173 |
+
[data-theme="light"] .text-gray-300 { color: #374151 !important; }
|
| 174 |
+
[data-theme="light"] .text-gray-400 { color: #4b5563 !important; }
|
| 175 |
+
[data-theme="light"] .text-gray-500 { color: #6b7280 !important; }
|
| 176 |
+
[data-theme="light"] .text-gray-600 { color: #9ca3af !important; }
|
| 177 |
+
[data-theme="light"] .text-violet-300 { color: #7c3aed !important; }
|
| 178 |
+
[data-theme="light"] .text-violet-400 { color: #7c3aed !important; }
|
| 179 |
+
[data-theme="light"] .text-green-400 { color: #15803d !important; }
|
| 180 |
+
[data-theme="light"] .text-red-400 { color: #b91c1c !important; }
|
| 181 |
+
[data-theme="light"] pre {
|
| 182 |
+
background-color: var(--bg-tertiary) !important;
|
| 183 |
+
color: #374151 !important;
|
| 184 |
+
}
|
| 185 |
+
[data-theme="light"] .recharts-cartesian-grid line {
|
| 186 |
+
stroke: rgba(0, 0, 0, 0.06) !important;
|
| 187 |
+
}
|
frontend/src/lib/api.ts
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { InitResponse, RLState, SchemaGraph, SSEEvent } from './types'
|
| 2 |
+
|
| 3 |
+
const BASE_URL: string = import.meta.env.VITE_API_URL ?? ''
|
| 4 |
+
|
| 5 |
+
async function* parseSSE(response: Response): AsyncGenerator<SSEEvent> {
|
| 6 |
+
const reader = response.body!.getReader()
|
| 7 |
+
const decoder = new TextDecoder()
|
| 8 |
+
let buffer = ''
|
| 9 |
+
|
| 10 |
+
while (true) {
|
| 11 |
+
const { done, value } = await reader.read()
|
| 12 |
+
if (done) break
|
| 13 |
+
buffer += decoder.decode(value, { stream: true })
|
| 14 |
+
const lines = buffer.split('\n')
|
| 15 |
+
buffer = lines.pop() ?? ''
|
| 16 |
+
|
| 17 |
+
for (const line of lines) {
|
| 18 |
+
if (!line.startsWith('data: ')) continue
|
| 19 |
+
const raw = line.slice(6).trim()
|
| 20 |
+
if (raw === '[DONE]') return
|
| 21 |
+
try {
|
| 22 |
+
yield JSON.parse(raw) as SSEEvent
|
| 23 |
+
} catch {
|
| 24 |
+
// ignore malformed lines
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
}
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
export async function* streamExecuteQuery(
|
| 31 |
+
question: string,
|
| 32 |
+
taskId: string
|
| 33 |
+
): AsyncGenerator<SSEEvent> {
|
| 34 |
+
const res = await fetch(`${BASE_URL}/api/execute-query`, {
|
| 35 |
+
method: 'POST',
|
| 36 |
+
headers: { 'Content-Type': 'application/json' },
|
| 37 |
+
body: JSON.stringify({ question, task_id: taskId }),
|
| 38 |
+
})
|
| 39 |
+
if (!res.ok) {
|
| 40 |
+
throw new Error(`HTTP ${res.status}: ${res.statusText}`)
|
| 41 |
+
}
|
| 42 |
+
yield* parseSSE(res)
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
export async function* streamBenchmark(
|
| 46 |
+
taskId: string,
|
| 47 |
+
queryIds?: string[]
|
| 48 |
+
): AsyncGenerator<SSEEvent> {
|
| 49 |
+
const body: Record<string, unknown> = { task_id: taskId }
|
| 50 |
+
if (queryIds) body.queryIds = queryIds
|
| 51 |
+
|
| 52 |
+
const res = await fetch(`${BASE_URL}/api/benchmark`, {
|
| 53 |
+
method: 'POST',
|
| 54 |
+
headers: { 'Content-Type': 'application/json' },
|
| 55 |
+
body: JSON.stringify(body),
|
| 56 |
+
})
|
| 57 |
+
if (!res.ok) {
|
| 58 |
+
throw new Error(`HTTP ${res.status}: ${res.statusText}`)
|
| 59 |
+
}
|
| 60 |
+
yield* parseSSE(res)
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
export async function fetchInit(): Promise<InitResponse> {
|
| 64 |
+
const res = await fetch(`${BASE_URL}/api/init`)
|
| 65 |
+
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
| 66 |
+
return res.json() as Promise<InitResponse>
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
export async function fetchRLState(): Promise<RLState> {
|
| 70 |
+
const res = await fetch(`${BASE_URL}/api/rl-state`)
|
| 71 |
+
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
| 72 |
+
return res.json() as Promise<RLState>
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
export async function fetchSchemaGraph(): Promise<SchemaGraph> {
|
| 76 |
+
const res = await fetch(`${BASE_URL}/api/schema-graph`)
|
| 77 |
+
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
| 78 |
+
return res.json() as Promise<SchemaGraph>
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
export async function submitFeedback(
|
| 82 |
+
question: string,
|
| 83 |
+
sql: string,
|
| 84 |
+
correct: boolean
|
| 85 |
+
): Promise<void> {
|
| 86 |
+
await fetch(`${BASE_URL}/api/feedback`, {
|
| 87 |
+
method: 'POST',
|
| 88 |
+
headers: { 'Content-Type': 'application/json' },
|
| 89 |
+
body: JSON.stringify({ question, sql, correct }),
|
| 90 |
+
})
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
export async function fetchPromptHistory() {
|
| 94 |
+
const res = await fetch(`${BASE_URL}/api/prompt-history`)
|
| 95 |
+
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
| 96 |
+
return res.json()
|
| 97 |
+
}
|
frontend/src/lib/types.ts
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// βββ Chat Types ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 2 |
+
|
| 3 |
+
export type MessageStatus = 'streaming' | 'done' | 'error'
|
| 4 |
+
export type FeedbackType = 'correct' | 'wrong' | null
|
| 5 |
+
|
| 6 |
+
export interface AttemptStep {
|
| 7 |
+
attempt: number
|
| 8 |
+
sql: string
|
| 9 |
+
error?: string
|
| 10 |
+
action?: string
|
| 11 |
+
actionScore?: number
|
| 12 |
+
reward?: number
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
export interface ChatMessage {
|
| 16 |
+
id: string
|
| 17 |
+
question: string
|
| 18 |
+
status: MessageStatus
|
| 19 |
+
sql: string
|
| 20 |
+
rows: Record<string, unknown>[]
|
| 21 |
+
rowCount: number
|
| 22 |
+
errorMsg?: string
|
| 23 |
+
attempts: number
|
| 24 |
+
steps: AttemptStep[]
|
| 25 |
+
reward?: number
|
| 26 |
+
rlAction?: string
|
| 27 |
+
rlActionScore?: number
|
| 28 |
+
feedback: FeedbackType
|
| 29 |
+
feedbackSending?: boolean
|
| 30 |
+
promptGeneration: number
|
| 31 |
+
streamingCursor?: boolean
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// βββ Benchmark Types βββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
+
|
| 36 |
+
export type BenchmarkStatus = 'pending' | 'running' | 'pass' | 'fail'
|
| 37 |
+
export type Difficulty = 'easy' | 'medium' | 'hard'
|
| 38 |
+
|
| 39 |
+
export interface BenchmarkQuery {
|
| 40 |
+
id: string
|
| 41 |
+
question: string
|
| 42 |
+
difficulty: Difficulty
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
export interface BenchmarkResult {
|
| 46 |
+
id: string
|
| 47 |
+
question: string
|
| 48 |
+
difficulty: Difficulty
|
| 49 |
+
status: BenchmarkStatus
|
| 50 |
+
score: number | null
|
| 51 |
+
sql: string | null
|
| 52 |
+
reason: string | null
|
| 53 |
+
attempts: number | null
|
| 54 |
+
refRowCount: number | null
|
| 55 |
+
agentRowCount: number | null
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// βββ RL State ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 59 |
+
|
| 60 |
+
export interface RLEpisode {
|
| 61 |
+
episode: number
|
| 62 |
+
totalReward: number
|
| 63 |
+
successRate: number
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
export interface ActionCount {
|
| 67 |
+
action: string
|
| 68 |
+
count: number
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
export interface RLState {
|
| 72 |
+
totalEpisodes: number
|
| 73 |
+
successRate: number
|
| 74 |
+
currentAlpha: number
|
| 75 |
+
episodes: RLEpisode[]
|
| 76 |
+
actionDistribution: ActionCount[]
|
| 77 |
+
currentGeneration: number
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// βββ GEPA / Prompt βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 81 |
+
|
| 82 |
+
export interface PromptSnapshot {
|
| 83 |
+
generation: number
|
| 84 |
+
prompt: string
|
| 85 |
+
score: number
|
| 86 |
+
summary: string
|
| 87 |
+
timestamp: string
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
// βββ Schema ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 91 |
+
|
| 92 |
+
export interface TableInfo {
|
| 93 |
+
name: string
|
| 94 |
+
rows: number
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
export interface ColumnInfo {
|
| 98 |
+
name: string
|
| 99 |
+
type: string
|
| 100 |
+
pk?: boolean
|
| 101 |
+
fk?: string
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
export interface SchemaTable {
|
| 105 |
+
name: string
|
| 106 |
+
columns: ColumnInfo[]
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
export interface SchemaRelationship {
|
| 110 |
+
from: string
|
| 111 |
+
fromCol: string
|
| 112 |
+
to: string
|
| 113 |
+
toCol: string
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
export interface SchemaGraph {
|
| 117 |
+
tables: SchemaTable[]
|
| 118 |
+
relationships: SchemaRelationship[]
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
// βββ API Response Types ββββββββββββββββββββββββββββββββββββββββββ
|
| 122 |
+
|
| 123 |
+
export interface InitResponse {
|
| 124 |
+
seeded: boolean
|
| 125 |
+
tables: TableInfo[]
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
export interface SSEEvent {
|
| 129 |
+
type: string
|
| 130 |
+
[key: string]: unknown
|
| 131 |
+
}
|
frontend/src/main.tsx
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react'
|
| 2 |
+
import ReactDOM from 'react-dom/client'
|
| 3 |
+
import App from './App'
|
| 4 |
+
import './index.css'
|
| 5 |
+
|
| 6 |
+
// Restore persisted theme
|
| 7 |
+
try {
|
| 8 |
+
const saved = localStorage.getItem('theme') as 'dark' | 'light' | null
|
| 9 |
+
if (saved) document.documentElement.setAttribute('data-theme', saved)
|
| 10 |
+
else document.documentElement.setAttribute('data-theme', 'dark')
|
| 11 |
+
} catch {
|
| 12 |
+
document.documentElement.setAttribute('data-theme', 'dark')
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
ReactDOM.createRoot(document.getElementById('root')!).render(
|
| 16 |
+
<React.StrictMode>
|
| 17 |
+
<App />
|
| 18 |
+
</React.StrictMode>
|
| 19 |
+
)
|
frontend/src/store/useStore.ts
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { create } from 'zustand'
|
| 2 |
+
import type {
|
| 3 |
+
ChatMessage,
|
| 4 |
+
BenchmarkResult,
|
| 5 |
+
RLState,
|
| 6 |
+
TableInfo,
|
| 7 |
+
SchemaGraph,
|
| 8 |
+
PromptSnapshot,
|
| 9 |
+
Difficulty,
|
| 10 |
+
} from '../lib/types'
|
| 11 |
+
|
| 12 |
+
interface Store {
|
| 13 |
+
// Theme
|
| 14 |
+
theme: 'dark' | 'light'
|
| 15 |
+
toggleTheme: () => void
|
| 16 |
+
|
| 17 |
+
// Task
|
| 18 |
+
taskId: string
|
| 19 |
+
taskDifficulty: Difficulty
|
| 20 |
+
setTaskId: (id: string) => void
|
| 21 |
+
setTaskDifficulty: (d: Difficulty) => void
|
| 22 |
+
|
| 23 |
+
// Init / DB
|
| 24 |
+
dbSeeded: boolean
|
| 25 |
+
setDbSeeded: (v: boolean) => void
|
| 26 |
+
tables: TableInfo[]
|
| 27 |
+
setTables: (tables: TableInfo[]) => void
|
| 28 |
+
schemaGraph: SchemaGraph | null
|
| 29 |
+
setSchemaGraph: (g: SchemaGraph) => void
|
| 30 |
+
|
| 31 |
+
// Chat
|
| 32 |
+
messages: ChatMessage[]
|
| 33 |
+
addMessage: (msg: ChatMessage) => void
|
| 34 |
+
updateMessage: (id: string, update: Partial<ChatMessage>) => void
|
| 35 |
+
clearMessages: () => void
|
| 36 |
+
isExecuting: boolean
|
| 37 |
+
setIsExecuting: (v: boolean) => void
|
| 38 |
+
optimizingBanner: boolean
|
| 39 |
+
setOptimizingBanner: (v: boolean) => void
|
| 40 |
+
|
| 41 |
+
// Benchmark
|
| 42 |
+
benchmarkResults: BenchmarkResult[]
|
| 43 |
+
setBenchmarkResults: (r: BenchmarkResult[]) => void
|
| 44 |
+
updateBenchmarkResult: (r: BenchmarkResult) => void
|
| 45 |
+
resetBenchmark: () => void
|
| 46 |
+
isBenchmarking: boolean
|
| 47 |
+
setIsBenchmarking: (v: boolean) => void
|
| 48 |
+
activeBenchmarkId: string | null
|
| 49 |
+
setActiveBenchmarkId: (id: string | null) => void
|
| 50 |
+
overallScore: number | null
|
| 51 |
+
setOverallScore: (s: number) => void
|
| 52 |
+
|
| 53 |
+
// RL State
|
| 54 |
+
rlState: RLState | null
|
| 55 |
+
setRlState: (s: RLState) => void
|
| 56 |
+
|
| 57 |
+
// GEPA / Prompt
|
| 58 |
+
currentPrompt: string
|
| 59 |
+
promptGeneration: number
|
| 60 |
+
promptHistory: PromptSnapshot[]
|
| 61 |
+
setPromptData: (data: { prompt: string; generation: number; history: PromptSnapshot[] }) => void
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
const EASY_QUERIES: BenchmarkResult[] = [
|
| 65 |
+
{ id: 'E1', question: 'Show all products', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 66 |
+
{ id: 'E2', question: 'List all users from the USA', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 67 |
+
{ id: 'E3', question: 'What product categories exist?', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 68 |
+
{ id: 'E4', question: 'How many orders are in the database?', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 69 |
+
{ id: 'E5', question: 'Show all sellers with their names', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
const MEDIUM_QUERIES: BenchmarkResult[] = [
|
| 73 |
+
{ id: 'M1', question: 'Top 5 sellers by total revenue', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 74 |
+
{ id: 'M2', question: 'Average order value by country', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 75 |
+
{ id: 'M3', question: 'Products with stock below 10 units', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 76 |
+
{ id: 'M4', question: 'Monthly order count for the last 12 months', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 77 |
+
{ id: 'M5', question: 'Categories ranked by number of products', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
const HARD_QUERIES: BenchmarkResult[] = [
|
| 81 |
+
{ id: 'H1', question: 'Rolling 7-day revenue for the past 30 days', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 82 |
+
{ id: 'H2', question: 'Seller ranking with rank change from previous month', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 83 |
+
{ id: 'H3', question: 'Cohort retention analysis by signup month', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 84 |
+
{ id: 'H4', question: 'Identify top products contributing to 80% of revenue (Pareto)', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 85 |
+
{ id: 'H5', question: 'Customer lifetime value segmented by acquisition channel', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
export const useStore = create<Store>((set) => ({
|
| 89 |
+
// Theme
|
| 90 |
+
theme: 'dark',
|
| 91 |
+
toggleTheme: () =>
|
| 92 |
+
set((s) => {
|
| 93 |
+
const next = s.theme === 'dark' ? 'light' : 'dark'
|
| 94 |
+
document.documentElement.setAttribute('data-theme', next)
|
| 95 |
+
try { localStorage.setItem('theme', next) } catch { /* noop */ }
|
| 96 |
+
return { theme: next }
|
| 97 |
+
}),
|
| 98 |
+
|
| 99 |
+
// Task
|
| 100 |
+
taskId: 'easy',
|
| 101 |
+
taskDifficulty: 'easy',
|
| 102 |
+
setTaskId: (id) => set({ taskId: id }),
|
| 103 |
+
setTaskDifficulty: (d) =>
|
| 104 |
+
set({
|
| 105 |
+
taskDifficulty: d,
|
| 106 |
+
taskId: d,
|
| 107 |
+
benchmarkResults:
|
| 108 |
+
d === 'easy' ? EASY_QUERIES : d === 'medium' ? MEDIUM_QUERIES : HARD_QUERIES,
|
| 109 |
+
overallScore: null,
|
| 110 |
+
}),
|
| 111 |
+
|
| 112 |
+
// Init
|
| 113 |
+
dbSeeded: false,
|
| 114 |
+
setDbSeeded: (v) => set({ dbSeeded: v }),
|
| 115 |
+
tables: [],
|
| 116 |
+
setTables: (tables) => set({ tables }),
|
| 117 |
+
schemaGraph: null,
|
| 118 |
+
setSchemaGraph: (g) => set({ schemaGraph: g }),
|
| 119 |
+
|
| 120 |
+
// Chat
|
| 121 |
+
messages: [],
|
| 122 |
+
addMessage: (msg) => set((s) => ({ messages: [...s.messages, msg] })),
|
| 123 |
+
updateMessage: (id, update) =>
|
| 124 |
+
set((s) => ({
|
| 125 |
+
messages: s.messages.map((m) => (m.id === id ? { ...m, ...update } : m)),
|
| 126 |
+
})),
|
| 127 |
+
clearMessages: () => set({ messages: [] }),
|
| 128 |
+
isExecuting: false,
|
| 129 |
+
setIsExecuting: (v) => set({ isExecuting: v }),
|
| 130 |
+
optimizingBanner: false,
|
| 131 |
+
setOptimizingBanner: (v) => set({ optimizingBanner: v }),
|
| 132 |
+
|
| 133 |
+
// Benchmark
|
| 134 |
+
benchmarkResults: EASY_QUERIES,
|
| 135 |
+
setBenchmarkResults: (r) => set({ benchmarkResults: r }),
|
| 136 |
+
updateBenchmarkResult: (r) =>
|
| 137 |
+
set((s) => ({
|
| 138 |
+
benchmarkResults: s.benchmarkResults.map((br) => (br.id === r.id ? r : br)),
|
| 139 |
+
})),
|
| 140 |
+
resetBenchmark: () =>
|
| 141 |
+
set((s) => ({
|
| 142 |
+
benchmarkResults: s.benchmarkResults.map((r) => ({
|
| 143 |
+
...r,
|
| 144 |
+
status: 'pending' as const,
|
| 145 |
+
score: null,
|
| 146 |
+
sql: null,
|
| 147 |
+
reason: null,
|
| 148 |
+
attempts: null,
|
| 149 |
+
refRowCount: null,
|
| 150 |
+
agentRowCount: null,
|
| 151 |
+
})),
|
| 152 |
+
overallScore: null,
|
| 153 |
+
})),
|
| 154 |
+
isBenchmarking: false,
|
| 155 |
+
setIsBenchmarking: (v) => set({ isBenchmarking: v }),
|
| 156 |
+
activeBenchmarkId: null,
|
| 157 |
+
setActiveBenchmarkId: (id) => set({ activeBenchmarkId: id }),
|
| 158 |
+
overallScore: null,
|
| 159 |
+
setOverallScore: (s) => set({ overallScore: s }),
|
| 160 |
+
|
| 161 |
+
// RL State
|
| 162 |
+
rlState: null,
|
| 163 |
+
setRlState: (s) => set({ rlState: s }),
|
| 164 |
+
|
| 165 |
+
// GEPA
|
| 166 |
+
currentPrompt: '',
|
| 167 |
+
promptGeneration: 0,
|
| 168 |
+
promptHistory: [],
|
| 169 |
+
setPromptData: (data) =>
|
| 170 |
+
set({
|
| 171 |
+
currentPrompt: data.prompt,
|
| 172 |
+
promptGeneration: data.generation,
|
| 173 |
+
promptHistory: data.history,
|
| 174 |
+
}),
|
| 175 |
+
}))
|
frontend/src/vite-env.d.ts
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/// <reference types="vite/client" />
|
| 2 |
+
|
| 3 |
+
interface ImportMetaEnv {
|
| 4 |
+
readonly VITE_API_URL?: string
|
| 5 |
+
}
|
| 6 |
+
|
| 7 |
+
interface ImportMeta {
|
| 8 |
+
readonly env: ImportMetaEnv
|
| 9 |
+
}
|
frontend/tailwind.config.js
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/** @type {import('tailwindcss').Config} */
|
| 2 |
+
export default {
|
| 3 |
+
content: [
|
| 4 |
+
'./index.html',
|
| 5 |
+
'./src/**/*.{js,ts,jsx,tsx}',
|
| 6 |
+
],
|
| 7 |
+
theme: {
|
| 8 |
+
extend: {
|
| 9 |
+
colors: {
|
| 10 |
+
'bg-primary': '#08080d',
|
| 11 |
+
'bg-secondary': '#09090f',
|
| 12 |
+
'bg-card': '#0e0e16',
|
| 13 |
+
},
|
| 14 |
+
fontFamily: {
|
| 15 |
+
mono: ['ui-monospace', '"SF Mono"', 'Consolas', '"Liberation Mono"', 'monospace'],
|
| 16 |
+
},
|
| 17 |
+
},
|
| 18 |
+
},
|
| 19 |
+
plugins: [],
|
| 20 |
+
}
|
frontend/tsconfig.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"compilerOptions": {
|
| 3 |
+
"target": "ES2020",
|
| 4 |
+
"useDefineForClassFields": true,
|
| 5 |
+
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
| 6 |
+
"module": "ESNext",
|
| 7 |
+
"skipLibCheck": true,
|
| 8 |
+
"moduleResolution": "bundler",
|
| 9 |
+
"allowImportingTsExtensions": true,
|
| 10 |
+
"resolveJsonModule": true,
|
| 11 |
+
"isolatedModules": true,
|
| 12 |
+
"noEmit": true,
|
| 13 |
+
"jsx": "react-jsx",
|
| 14 |
+
"strict": true,
|
| 15 |
+
"noUnusedLocals": true,
|
| 16 |
+
"noUnusedParameters": true,
|
| 17 |
+
"noFallthroughCasesInSwitch": true,
|
| 18 |
+
"baseUrl": ".",
|
| 19 |
+
"paths": {
|
| 20 |
+
"@/*": ["./src/*"]
|
| 21 |
+
}
|
| 22 |
+
},
|
| 23 |
+
"include": ["src"]
|
| 24 |
+
}
|
frontend/vite.config.ts
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { defineConfig } from 'vite'
|
| 2 |
+
import react from '@vitejs/plugin-react'
|
| 3 |
+
import path from 'path'
|
| 4 |
+
|
| 5 |
+
export default defineConfig({
|
| 6 |
+
plugins: [react()],
|
| 7 |
+
resolve: {
|
| 8 |
+
alias: {
|
| 9 |
+
'@': path.resolve(__dirname, './src'),
|
| 10 |
+
},
|
| 11 |
+
},
|
| 12 |
+
server: {
|
| 13 |
+
port: 5173,
|
| 14 |
+
proxy: {
|
| 15 |
+
'/api': {
|
| 16 |
+
target: 'http://localhost:8000',
|
| 17 |
+
changeOrigin: true,
|
| 18 |
+
},
|
| 19 |
+
'/env': {
|
| 20 |
+
target: 'http://localhost:8000',
|
| 21 |
+
changeOrigin: true,
|
| 22 |
+
},
|
| 23 |
+
},
|
| 24 |
+
},
|
| 25 |
+
build: {
|
| 26 |
+
outDir: 'dist',
|
| 27 |
+
},
|
| 28 |
+
})
|
inference.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQL Agent OpenEnv β Baseline Inference Script
|
| 3 |
+
==============================================
|
| 4 |
+
|
| 5 |
+
Runs a baseline LLM agent against all 3 tasks of the SQL Agent OpenEnv environment.
|
| 6 |
+
|
| 7 |
+
Environment variables (required):
|
| 8 |
+
API_BASE_URL β OpenAI-compatible base URL (default: https://router.huggingface.co/v1)
|
| 9 |
+
MODEL_NAME β Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
|
| 10 |
+
HF_TOKEN β Hugging Face / API key
|
| 11 |
+
|
| 12 |
+
STDOUT format (strictly enforced):
|
| 13 |
+
[START] task=<task_id> env=sql-agent-openenv model=<model>
|
| 14 |
+
[STEP] step=<n> action=<action> reward=<0.00> done=<true|false> error=<msg|null>
|
| 15 |
+
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import asyncio
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import textwrap
|
| 24 |
+
from typing import List, Optional
|
| 25 |
+
|
| 26 |
+
# ββ Path setup (inference.py lives at repo root; backend is a subdirectory) ββ
|
| 27 |
+
_BACKEND = os.path.join(os.path.dirname(os.path.abspath(__file__)), "backend")
|
| 28 |
+
if _BACKEND not in sys.path:
|
| 29 |
+
sys.path.insert(0, _BACKEND)
|
| 30 |
+
|
| 31 |
+
from openai import OpenAI # noqa: E402
|
| 32 |
+
|
| 33 |
+
from env.sql_env import SQLAgentEnv, Action, Observation # noqa: E402
|
| 34 |
+
|
| 35 |
+
# ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
|
| 37 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY", "")
|
| 38 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 39 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 40 |
+
BENCHMARK = "sql-agent-openenv"
|
| 41 |
+
|
| 42 |
+
TASKS = ["simple_queries", "join_queries", "complex_queries"]
|
| 43 |
+
MAX_STEPS = 5
|
| 44 |
+
TEMPERATURE = 0.2
|
| 45 |
+
MAX_TOKENS = 50
|
| 46 |
+
|
| 47 |
+
REPAIR_ACTIONS = [
|
| 48 |
+
"rewrite_full",
|
| 49 |
+
"fix_column",
|
| 50 |
+
"fix_table",
|
| 51 |
+
"add_groupby",
|
| 52 |
+
"rewrite_cte",
|
| 53 |
+
"fix_syntax",
|
| 54 |
+
"change_dialect",
|
| 55 |
+
"relax_filter",
|
| 56 |
+
]
|
| 57 |
+
|
| 58 |
+
SYSTEM_PROMPT = textwrap.dedent("""
|
| 59 |
+
You are an expert SQL agent interacting with a SQL repair environment.
|
| 60 |
+
|
| 61 |
+
At each step you receive a natural language question, a database schema,
|
| 62 |
+
and optionally the last SQL attempt + error message.
|
| 63 |
+
|
| 64 |
+
Your job: pick ONE repair action from the list below that is most likely
|
| 65 |
+
to fix the SQL error on the next attempt.
|
| 66 |
+
|
| 67 |
+
Available actions:
|
| 68 |
+
generate β write fresh SQL from scratch (use on first attempt)
|
| 69 |
+
rewrite_full β completely rewrite the query from scratch
|
| 70 |
+
fix_column β fix wrong column name references
|
| 71 |
+
fix_table β fix wrong table name references
|
| 72 |
+
add_groupby β add or fix GROUP BY / aggregation clauses
|
| 73 |
+
rewrite_cte β restructure subqueries or CTEs
|
| 74 |
+
fix_syntax β fix syntax errors (brackets, commas, keywords)
|
| 75 |
+
change_dialect β convert to SQLite-compatible functions
|
| 76 |
+
relax_filter β broaden or remove overly strict WHERE conditions
|
| 77 |
+
|
| 78 |
+
Reply with ONLY the action name. No explanation. No punctuation.
|
| 79 |
+
Example: fix_column
|
| 80 |
+
""").strip()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ββ Logging βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 84 |
+
|
| 85 |
+
def log_start(task: str, model: str) -> None:
|
| 86 |
+
print(f"[START] task={task} env={BENCHMARK} model={model}", flush=True)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 90 |
+
error_val = error.replace("\n", " ").strip() if error else "null"
|
| 91 |
+
done_val = str(done).lower()
|
| 92 |
+
print(
|
| 93 |
+
f"[STEP] step={step} action={action} reward={reward:.2f} "
|
| 94 |
+
f"done={done_val} error={error_val}",
|
| 95 |
+
flush=True,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 100 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 101 |
+
print(
|
| 102 |
+
f"[END] success={str(success).lower()} steps={steps} "
|
| 103 |
+
f"score={score:.3f} rewards={rewards_str}",
|
| 104 |
+
flush=True,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ββ LLM helper ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 109 |
+
|
| 110 |
+
def pick_action(
|
| 111 |
+
client: OpenAI,
|
| 112 |
+
obs: Observation,
|
| 113 |
+
step: int,
|
| 114 |
+
) -> str:
|
| 115 |
+
"""Ask the LLM to pick a repair action given the current observation."""
|
| 116 |
+
if step == 1 or obs.current_sql is None:
|
| 117 |
+
return "generate"
|
| 118 |
+
|
| 119 |
+
user_msg = textwrap.dedent(f"""
|
| 120 |
+
Question: {obs.question}
|
| 121 |
+
|
| 122 |
+
Current SQL (failed):
|
| 123 |
+
{obs.current_sql}
|
| 124 |
+
|
| 125 |
+
Error: {obs.error_message or "unknown"}
|
| 126 |
+
Error class: {obs.error_class or "unknown"}
|
| 127 |
+
Attempt number: {obs.attempt_number} of {obs.max_attempts}
|
| 128 |
+
|
| 129 |
+
Which repair action should I use next?
|
| 130 |
+
""").strip()
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
completion = client.chat.completions.create(
|
| 134 |
+
model=MODEL_NAME,
|
| 135 |
+
messages=[
|
| 136 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 137 |
+
{"role": "user", "content": user_msg},
|
| 138 |
+
],
|
| 139 |
+
temperature=TEMPERATURE,
|
| 140 |
+
max_tokens=MAX_TOKENS,
|
| 141 |
+
)
|
| 142 |
+
raw = (completion.choices[0].message.content or "").strip().lower()
|
| 143 |
+
# Normalise to valid action name
|
| 144 |
+
for action in REPAIR_ACTIONS:
|
| 145 |
+
if action in raw:
|
| 146 |
+
return action
|
| 147 |
+
return "rewrite_full"
|
| 148 |
+
except Exception as exc:
|
| 149 |
+
print(f"[DEBUG] LLM call failed: {exc}", flush=True)
|
| 150 |
+
return "rewrite_full"
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ββ Single-episode runner βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 154 |
+
|
| 155 |
+
async def run_episode(
|
| 156 |
+
env: SQLAgentEnv,
|
| 157 |
+
client: OpenAI,
|
| 158 |
+
task_id: str,
|
| 159 |
+
) -> None:
|
| 160 |
+
"""Run one full episode for a task, emitting structured stdout logs."""
|
| 161 |
+
log_start(task=task_id, model=MODEL_NAME)
|
| 162 |
+
|
| 163 |
+
rewards: List[float] = []
|
| 164 |
+
steps_taken = 0
|
| 165 |
+
score = 0.0
|
| 166 |
+
success = False
|
| 167 |
+
last_error: Optional[str] = None
|
| 168 |
+
|
| 169 |
+
try:
|
| 170 |
+
obs = env.reset(task_id)
|
| 171 |
+
|
| 172 |
+
for step in range(1, MAX_STEPS + 1):
|
| 173 |
+
action_name = pick_action(client, obs, step)
|
| 174 |
+
action = Action(repair_action=action_name)
|
| 175 |
+
|
| 176 |
+
try:
|
| 177 |
+
obs, reward_info = await env.step(action)
|
| 178 |
+
except RuntimeError as exc:
|
| 179 |
+
log_step(step=step, action=action_name, reward=0.0, done=True, error=str(exc))
|
| 180 |
+
rewards.append(0.0)
|
| 181 |
+
steps_taken = step
|
| 182 |
+
break
|
| 183 |
+
|
| 184 |
+
reward = reward_info.value
|
| 185 |
+
done = reward_info.done
|
| 186 |
+
last_error = obs.error_message
|
| 187 |
+
success = reward_info.success
|
| 188 |
+
|
| 189 |
+
rewards.append(reward)
|
| 190 |
+
steps_taken = step
|
| 191 |
+
|
| 192 |
+
log_step(
|
| 193 |
+
step=step,
|
| 194 |
+
action=action_name,
|
| 195 |
+
reward=reward,
|
| 196 |
+
done=done,
|
| 197 |
+
error=last_error,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
if done:
|
| 201 |
+
break
|
| 202 |
+
|
| 203 |
+
# Score: clamp sum of rewards to [0, 1]
|
| 204 |
+
total = sum(rewards)
|
| 205 |
+
max_possible = MAX_STEPS * 1.0 # max reward per step is 1.0
|
| 206 |
+
score = min(max(total / max_possible, 0.0), 1.0)
|
| 207 |
+
|
| 208 |
+
finally:
|
| 209 |
+
log_end(
|
| 210 |
+
success=success,
|
| 211 |
+
steps=steps_taken,
|
| 212 |
+
score=score,
|
| 213 |
+
rewards=rewards,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ββ Main ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 218 |
+
|
| 219 |
+
async def main() -> None:
|
| 220 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 221 |
+
env = SQLAgentEnv()
|
| 222 |
+
|
| 223 |
+
for task_id in TASKS:
|
| 224 |
+
await run_episode(env, client, task_id)
|
| 225 |
+
# Small gap between tasks for readability
|
| 226 |
+
print("", flush=True)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
if __name__ == "__main__":
|
| 230 |
+
asyncio.run(main())
|
openenv.yaml
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: sql-agent-openenv
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
A SQL generation and repair environment where an AI agent learns to write
|
| 5 |
+
correct SQL queries through a self-debugging loop powered by a LinUCB
|
| 6 |
+
contextual bandit and GEPA prompt evolution. Models real-world data analyst
|
| 7 |
+
workflows β querying databases with natural language, handling errors, and
|
| 8 |
+
iteratively improving.
|
| 9 |
+
|
| 10 |
+
author: sql-agent-openenv-team
|
| 11 |
+
tags:
|
| 12 |
+
- openenv
|
| 13 |
+
- sql
|
| 14 |
+
- rl
|
| 15 |
+
- nlp
|
| 16 |
+
- contextual-bandit
|
| 17 |
+
|
| 18 |
+
# ββ Endpoints ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 19 |
+
api:
|
| 20 |
+
reset: /reset
|
| 21 |
+
step: /step
|
| 22 |
+
state: /state
|
| 23 |
+
|
| 24 |
+
# ββ Action Space βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
+
action_space:
|
| 26 |
+
type: discrete
|
| 27 |
+
n: 9
|
| 28 |
+
actions:
|
| 29 |
+
- name: generate
|
| 30 |
+
description: "Generate SQL from scratch (first attempt)"
|
| 31 |
+
- name: rewrite_full
|
| 32 |
+
description: "Completely rewrite the query from scratch"
|
| 33 |
+
- name: fix_column
|
| 34 |
+
description: "Fix wrong column name references using schema"
|
| 35 |
+
- name: fix_table
|
| 36 |
+
description: "Fix wrong table name references or JOIN structure"
|
| 37 |
+
- name: add_groupby
|
| 38 |
+
description: "Add or fix GROUP BY / aggregation clauses"
|
| 39 |
+
- name: rewrite_cte
|
| 40 |
+
description: "Restructure CTEs or subqueries"
|
| 41 |
+
- name: fix_syntax
|
| 42 |
+
description: "Fix syntax errors (brackets, commas, keywords)"
|
| 43 |
+
- name: change_dialect
|
| 44 |
+
description: "Convert to SQLite-compatible functions"
|
| 45 |
+
- name: relax_filter
|
| 46 |
+
description: "Broaden or remove overly strict WHERE conditions"
|
| 47 |
+
|
| 48 |
+
# ββ Observation Space ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
observation_space:
|
| 50 |
+
type: dict
|
| 51 |
+
fields:
|
| 52 |
+
- name: question
|
| 53 |
+
type: string
|
| 54 |
+
description: "Natural language question to answer with SQL"
|
| 55 |
+
- name: schema_info
|
| 56 |
+
type: string
|
| 57 |
+
description: "Full database schema (tables, columns, types, FK relationships)"
|
| 58 |
+
- name: current_sql
|
| 59 |
+
type: string
|
| 60 |
+
nullable: true
|
| 61 |
+
description: "The SQL generated on the last attempt (null on first step)"
|
| 62 |
+
- name: error_message
|
| 63 |
+
type: string
|
| 64 |
+
nullable: true
|
| 65 |
+
description: "SQLite error message from the last attempt (null on success)"
|
| 66 |
+
- name: error_class
|
| 67 |
+
type: string
|
| 68 |
+
nullable: true
|
| 69 |
+
description: "Classified error type (e.g. no_such_column, syntax_error)"
|
| 70 |
+
- name: attempt_number
|
| 71 |
+
type: integer
|
| 72 |
+
description: "Current attempt number (0 at reset, increments each step)"
|
| 73 |
+
- name: max_attempts
|
| 74 |
+
type: integer
|
| 75 |
+
description: "Maximum allowed attempts per episode (5)"
|
| 76 |
+
- name: task_id
|
| 77 |
+
type: string
|
| 78 |
+
description: "Active task identifier"
|
| 79 |
+
- name: task_difficulty
|
| 80 |
+
type: string
|
| 81 |
+
description: "Task difficulty level: easy | medium | hard"
|
| 82 |
+
|
| 83 |
+
# ββ Reward βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 84 |
+
reward:
|
| 85 |
+
range: [-1.5, 1.5]
|
| 86 |
+
description: >
|
| 87 |
+
Shaped reward providing partial progress signals throughout the episode.
|
| 88 |
+
Success on attempt N: 1.0 - 0.1*(N-1).
|
| 89 |
+
Failure step: -0.1 - 0.05*N + severity_improvement_bonus + error_class_change_bonus.
|
| 90 |
+
Penalizes infinite loops (consecutive same error) and rewards convergence toward correct SQL.
|
| 91 |
+
|
| 92 |
+
# ββ Tasks ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 93 |
+
tasks:
|
| 94 |
+
- id: simple_queries
|
| 95 |
+
name: Simple SQL Queries
|
| 96 |
+
difficulty: easy
|
| 97 |
+
description: >
|
| 98 |
+
Single-table SELECT queries. Agent must retrieve correct rows by applying
|
| 99 |
+
basic filters and projections on the marketplace database.
|
| 100 |
+
question_count: 5
|
| 101 |
+
grader: >
|
| 102 |
+
Checks that required output columns are present and row count falls
|
| 103 |
+
within expected bounds. Attempt penalty not applied.
|
| 104 |
+
|
| 105 |
+
- id: join_queries
|
| 106 |
+
name: SQL Join Queries
|
| 107 |
+
difficulty: medium
|
| 108 |
+
description: >
|
| 109 |
+
Multi-table JOIN queries with GROUP BY and aggregation. Agent must
|
| 110 |
+
correctly join tables and compute aggregates over the marketplace data.
|
| 111 |
+
question_count: 5
|
| 112 |
+
grader: >
|
| 113 |
+
Correct columns + row count score multiplied by (1.0 - 0.1*(attempts-1)).
|
| 114 |
+
Rewards efficient, first-try solutions.
|
| 115 |
+
|
| 116 |
+
- id: complex_queries
|
| 117 |
+
name: Complex SQL Queries
|
| 118 |
+
difficulty: hard
|
| 119 |
+
description: >
|
| 120 |
+
Advanced queries using CTEs, window functions, nested aggregations, and
|
| 121 |
+
multi-level joins. Requires precise SQLite syntax knowledge.
|
| 122 |
+
question_count: 5
|
| 123 |
+
grader: >
|
| 124 |
+
Strict correctness required. Score capped at 0.8 without first-attempt
|
| 125 |
+
bonus. Attempt penalty of 0.1*(attempts-1) applied. Hard tasks genuinely
|
| 126 |
+
challenge frontier models.
|
| 127 |
+
|
| 128 |
+
# ββ Environment Metadata βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 129 |
+
metadata:
|
| 130 |
+
max_steps_per_episode: 5
|
| 131 |
+
database: SQLite (marketplace schema β users, products, orders, reviews, sellers)
|
| 132 |
+
rl_algorithm: LinUCB contextual bandit (feature_dim=20, 8 repair actions)
|
| 133 |
+
prompt_optimizer: GEPA (Generative Evolutionary Prompt Adaptation)
|
| 134 |
+
runtime_estimate_minutes: 5
|
| 135 |
+
compute_requirements:
|
| 136 |
+
vcpu: 2
|
| 137 |
+
memory_gb: 4
|