Spaces:
Sleeping
Sleeping
Fix chat query failures and benchmark ID mismatches
Browse files- Free-form chat: use error-based success check instead of task grader
(grader compared against wrong expected results, causing all queries to fail)
- Add /api/benchmark-questions endpoint to expose real task question IDs
- Benchmark SSE: rename query_idβid, successβpass, overall_scoreβoverallScore
to match frontend field names
- Add queryIds filtering support to BenchmarkRequest
- Frontend: load benchmark questions from API instead of hardcoded IDs
(E1-E5 didn't match backend sq-01βsq-05)
- Remove duplicate difficulty tabs from BenchmarkPanel
- backend/api/demo.py +33 -10
- frontend/src/App.tsx +26 -2
- frontend/src/components/BenchmarkPanel.tsx +0 -17
- frontend/src/lib/api.ts +8 -0
- frontend/src/store/useStore.ts +17 -29
backend/api/demo.py
CHANGED
|
@@ -179,11 +179,9 @@ async def execute_query_stream(req: ExecuteQueryRequest):
|
|
| 179 |
|
| 180 |
rows, error = execute_query(generated_sql)
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
)
|
| 186 |
-
attempt_success = task_score >= 0.8
|
| 187 |
|
| 188 |
current_error_class = None
|
| 189 |
error_class_name = None
|
|
@@ -320,10 +318,30 @@ async def execute_query_stream(req: ExecuteQueryRequest):
|
|
| 320 |
return EventSourceResponse(event_generator())
|
| 321 |
|
| 322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
# βββ /api/benchmark βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 324 |
|
| 325 |
class BenchmarkRequest(BaseModel):
|
| 326 |
task_id: str = "simple_queries"
|
|
|
|
| 327 |
|
| 328 |
|
| 329 |
@router.post("/benchmark")
|
|
@@ -333,10 +351,14 @@ async def run_benchmark(req: BenchmarkRequest):
|
|
| 333 |
task = get_task(task_id)
|
| 334 |
scores: list[float] = []
|
| 335 |
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
yield {"data": json.dumps({
|
| 338 |
"type": "query_start",
|
| 339 |
-
"
|
| 340 |
"question": question_obj.question,
|
| 341 |
})}
|
| 342 |
|
|
@@ -444,17 +466,18 @@ async def run_benchmark(req: BenchmarkRequest):
|
|
| 444 |
|
| 445 |
yield {"data": json.dumps({
|
| 446 |
"type": "query_result",
|
| 447 |
-
"
|
| 448 |
-
"
|
| 449 |
"score": task_score,
|
| 450 |
"sql": sql,
|
| 451 |
"attempts": attempt,
|
|
|
|
| 452 |
})}
|
| 453 |
|
| 454 |
overall_score = sum(scores) / len(scores) if scores else 0.0
|
| 455 |
yield {"data": json.dumps({
|
| 456 |
"type": "done",
|
| 457 |
-
"
|
| 458 |
"task_id": task_id,
|
| 459 |
})}
|
| 460 |
|
|
|
|
| 179 |
|
| 180 |
rows, error = execute_query(generated_sql)
|
| 181 |
|
| 182 |
+
# For free-form chat, success = no SQL error (not task grader)
|
| 183 |
+
attempt_success = (error is None)
|
| 184 |
+
task_score = 1.0 if attempt_success else 0.0
|
|
|
|
|
|
|
| 185 |
|
| 186 |
current_error_class = None
|
| 187 |
error_class_name = None
|
|
|
|
| 318 |
return EventSourceResponse(event_generator())
|
| 319 |
|
| 320 |
|
| 321 |
+
# βββ /api/benchmark-questions ββββββββββββββββββββββββββββββββββββ
|
| 322 |
+
|
| 323 |
+
@router.get("/benchmark-questions")
|
| 324 |
+
async def get_benchmark_questions(task_id: str = "easy"):
|
| 325 |
+
mapped_id = _DIFFICULTY_MAP.get(task_id, task_id)
|
| 326 |
+
task = get_task(mapped_id)
|
| 327 |
+
difficulty_label = task.difficulty # "easy" | "medium" | "hard"
|
| 328 |
+
return {
|
| 329 |
+
"questions": [
|
| 330 |
+
{
|
| 331 |
+
"id": q.id,
|
| 332 |
+
"question": q.question,
|
| 333 |
+
"difficulty": difficulty_label,
|
| 334 |
+
}
|
| 335 |
+
for q in task.questions
|
| 336 |
+
]
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
|
| 340 |
# βββ /api/benchmark βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 341 |
|
| 342 |
class BenchmarkRequest(BaseModel):
|
| 343 |
task_id: str = "simple_queries"
|
| 344 |
+
queryIds: Optional[list[str]] = None
|
| 345 |
|
| 346 |
|
| 347 |
@router.post("/benchmark")
|
|
|
|
| 351 |
task = get_task(task_id)
|
| 352 |
scores: list[float] = []
|
| 353 |
|
| 354 |
+
questions = task.questions
|
| 355 |
+
if req.queryIds:
|
| 356 |
+
questions = [q for q in questions if q.id in req.queryIds]
|
| 357 |
+
|
| 358 |
+
for question_obj in questions:
|
| 359 |
yield {"data": json.dumps({
|
| 360 |
"type": "query_start",
|
| 361 |
+
"id": question_obj.id,
|
| 362 |
"question": question_obj.question,
|
| 363 |
})}
|
| 364 |
|
|
|
|
| 466 |
|
| 467 |
yield {"data": json.dumps({
|
| 468 |
"type": "query_result",
|
| 469 |
+
"id": question_obj.id,
|
| 470 |
+
"pass": success,
|
| 471 |
"score": task_score,
|
| 472 |
"sql": sql,
|
| 473 |
"attempts": attempt,
|
| 474 |
+
"reason": "Correct" if success else "Agent exhausted all repair attempts",
|
| 475 |
})}
|
| 476 |
|
| 477 |
overall_score = sum(scores) / len(scores) if scores else 0.0
|
| 478 |
yield {"data": json.dumps({
|
| 479 |
"type": "done",
|
| 480 |
+
"overallScore": overall_score,
|
| 481 |
"task_id": task_id,
|
| 482 |
})}
|
| 483 |
|
frontend/src/App.tsx
CHANGED
|
@@ -11,7 +11,7 @@ import { RightSidebar } from './components/RightSidebar'
|
|
| 11 |
import { DemoMode } from './components/DemoMode'
|
| 12 |
import { ConnectDB } from './components/ConnectDB'
|
| 13 |
import { useStore } from './store/useStore'
|
| 14 |
-
import { fetchInit } from './lib/api'
|
| 15 |
|
| 16 |
type Tab = 'chat' | 'benchmark' | 'er'
|
| 17 |
|
|
@@ -28,7 +28,7 @@ export default function App() {
|
|
| 28 |
const [demoOpen, setDemoOpen] = useState(false)
|
| 29 |
const [connectDbOpen, setConnectDbOpen] = useState(false)
|
| 30 |
|
| 31 |
-
const { theme, setDbSeeded, setTables, setSchemaGraph, setDbLabel } = useStore()
|
| 32 |
|
| 33 |
// Apply theme on mount / change
|
| 34 |
useEffect(() => {
|
|
@@ -62,6 +62,30 @@ export default function App() {
|
|
| 62 |
.catch(() => { /* noop */ })
|
| 63 |
}, [setDbSeeded, setTables, setSchemaGraph])
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
// Close mobile sidebars on tab change
|
| 66 |
useEffect(() => {
|
| 67 |
setLeftOpen(false)
|
|
|
|
| 11 |
import { DemoMode } from './components/DemoMode'
|
| 12 |
import { ConnectDB } from './components/ConnectDB'
|
| 13 |
import { useStore } from './store/useStore'
|
| 14 |
+
import { fetchInit, fetchBenchmarkQuestions } from './lib/api'
|
| 15 |
|
| 16 |
type Tab = 'chat' | 'benchmark' | 'er'
|
| 17 |
|
|
|
|
| 28 |
const [demoOpen, setDemoOpen] = useState(false)
|
| 29 |
const [connectDbOpen, setConnectDbOpen] = useState(false)
|
| 30 |
|
| 31 |
+
const { theme, setDbSeeded, setTables, setSchemaGraph, setDbLabel, taskDifficulty } = useStore()
|
| 32 |
|
| 33 |
// Apply theme on mount / change
|
| 34 |
useEffect(() => {
|
|
|
|
| 62 |
.catch(() => { /* noop */ })
|
| 63 |
}, [setDbSeeded, setTables, setSchemaGraph])
|
| 64 |
|
| 65 |
+
// Load benchmark questions from API on mount
|
| 66 |
+
useEffect(() => {
|
| 67 |
+
const { setBenchmarkResults } = useStore.getState()
|
| 68 |
+
fetchBenchmarkQuestions(taskDifficulty)
|
| 69 |
+
.then(({ questions }) => {
|
| 70 |
+
setBenchmarkResults(
|
| 71 |
+
questions.map((q) => ({
|
| 72 |
+
id: q.id,
|
| 73 |
+
question: q.question,
|
| 74 |
+
difficulty: q.difficulty as 'easy' | 'medium' | 'hard',
|
| 75 |
+
status: 'pending' as const,
|
| 76 |
+
score: null,
|
| 77 |
+
sql: null,
|
| 78 |
+
reason: null,
|
| 79 |
+
attempts: null,
|
| 80 |
+
refRowCount: null,
|
| 81 |
+
agentRowCount: null,
|
| 82 |
+
}))
|
| 83 |
+
)
|
| 84 |
+
})
|
| 85 |
+
.catch(() => { /* noop */ })
|
| 86 |
+
// eslint-disable-next-line react-hooks/exhaustive-deps
|
| 87 |
+
}, [])
|
| 88 |
+
|
| 89 |
// Close mobile sidebars on tab change
|
| 90 |
useEffect(() => {
|
| 91 |
setLeftOpen(false)
|
frontend/src/components/BenchmarkPanel.tsx
CHANGED
|
@@ -356,23 +356,6 @@ export function BenchmarkPanel() {
|
|
| 356 |
)}
|
| 357 |
</div>
|
| 358 |
|
| 359 |
-
{/* Difficulty tabs */}
|
| 360 |
-
<div className="flex items-center gap-1 px-4 py-2 border-b border-white/[0.06] shrink-0">
|
| 361 |
-
{DIFFICULTY_TABS.map((tab) => (
|
| 362 |
-
<button
|
| 363 |
-
key={tab.id}
|
| 364 |
-
onClick={() => setTaskDifficulty(tab.id)}
|
| 365 |
-
className={`px-3 py-1 rounded-lg text-xs font-medium transition-all ${
|
| 366 |
-
taskDifficulty === tab.id
|
| 367 |
-
? 'bg-violet-600/20 text-violet-300 border border-violet-500/30'
|
| 368 |
-
: 'text-gray-500 hover:text-gray-300 hover:bg-white/5 border border-transparent'
|
| 369 |
-
}`}
|
| 370 |
-
>
|
| 371 |
-
{tab.label}
|
| 372 |
-
</button>
|
| 373 |
-
))}
|
| 374 |
-
</div>
|
| 375 |
-
|
| 376 |
{/* Query list */}
|
| 377 |
<div className="flex-1 overflow-y-auto">
|
| 378 |
<div className="p-2 flex flex-col gap-1">
|
|
|
|
| 356 |
)}
|
| 357 |
</div>
|
| 358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
{/* Query list */}
|
| 360 |
<div className="flex-1 overflow-y-auto">
|
| 361 |
<div className="p-2 flex flex-col gap-1">
|
frontend/src/lib/api.ts
CHANGED
|
@@ -96,6 +96,14 @@ export async function fetchPromptHistory() {
|
|
| 96 |
return res.json()
|
| 97 |
}
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
export async function connectExternalDb(path: string): Promise<{ success: boolean; message: string; tables: { name: string; rows: number }[]; dbLabel: string }> {
|
| 100 |
const res = await fetch(`${BASE_URL}/api/connect-db`, {
|
| 101 |
method: 'POST',
|
|
|
|
| 96 |
return res.json()
|
| 97 |
}
|
| 98 |
|
| 99 |
+
export async function fetchBenchmarkQuestions(
|
| 100 |
+
taskId: string
|
| 101 |
+
): Promise<{ questions: { id: string; question: string; difficulty: string }[] }> {
|
| 102 |
+
const res = await fetch(`${BASE_URL}/api/benchmark-questions?task_id=${encodeURIComponent(taskId)}`)
|
| 103 |
+
if (!res.ok) throw new Error(`HTTP ${res.status}`)
|
| 104 |
+
return res.json()
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
export async function connectExternalDb(path: string): Promise<{ success: boolean; message: string; tables: { name: string; rows: number }[]; dbLabel: string }> {
|
| 108 |
const res = await fetch(`${BASE_URL}/api/connect-db`, {
|
| 109 |
method: 'POST',
|
frontend/src/store/useStore.ts
CHANGED
|
@@ -8,6 +8,7 @@ import type {
|
|
| 8 |
PromptSnapshot,
|
| 9 |
Difficulty,
|
| 10 |
} from '../lib/types'
|
|
|
|
| 11 |
|
| 12 |
interface Store {
|
| 13 |
// Theme
|
|
@@ -64,28 +65,12 @@ interface Store {
|
|
| 64 |
setPromptData: (data: { prompt: string; generation: number; history: PromptSnapshot[] }) => void
|
| 65 |
}
|
| 66 |
|
| 67 |
-
|
| 68 |
-
{ id
|
| 69 |
-
|
| 70 |
-
{ id: 'E3', question: 'What product categories exist?', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 71 |
-
{ id: 'E4', question: 'How many orders are in the database?', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 72 |
-
{ id: 'E5', question: 'Show all sellers with their names', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 73 |
-
]
|
| 74 |
-
|
| 75 |
-
const MEDIUM_QUERIES: BenchmarkResult[] = [
|
| 76 |
-
{ id: 'M1', question: 'Top 5 sellers by total revenue', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 77 |
-
{ id: 'M2', question: 'Average order value by country', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 78 |
-
{ id: 'M3', question: 'Products with stock below 10 units', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 79 |
-
{ id: 'M4', question: 'Monthly order count for the last 12 months', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 80 |
-
{ id: 'M5', question: 'Categories ranked by number of products', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 81 |
-
]
|
| 82 |
|
| 83 |
-
const
|
| 84 |
-
|
| 85 |
-
{ id: 'H2', question: 'Seller ranking with rank change from previous month', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 86 |
-
{ id: 'H3', question: 'Cohort retention analysis by signup month', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 87 |
-
{ id: 'H4', question: 'Identify top products contributing to 80% of revenue (Pareto)', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 88 |
-
{ id: 'H5', question: 'Customer lifetime value segmented by acquisition channel', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
|
| 89 |
]
|
| 90 |
|
| 91 |
export const useStore = create<Store>((set) => ({
|
|
@@ -105,13 +90,16 @@ export const useStore = create<Store>((set) => ({
|
|
| 105 |
setTaskId: (id) => set({ taskId: id }),
|
| 106 |
setTaskDifficulty: (d) => {
|
| 107 |
const taskId = d === 'easy' ? 'simple_queries' : d === 'medium' ? 'join_queries' : 'complex_queries'
|
| 108 |
-
set({
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
| 115 |
},
|
| 116 |
// DB
|
| 117 |
dbLabel: 'benchmark (built-in)',
|
|
@@ -139,7 +127,7 @@ export const useStore = create<Store>((set) => ({
|
|
| 139 |
setOptimizingBanner: (v) => set({ optimizingBanner: v }),
|
| 140 |
|
| 141 |
// Benchmark
|
| 142 |
-
benchmarkResults:
|
| 143 |
setBenchmarkResults: (r) => set({ benchmarkResults: r }),
|
| 144 |
updateBenchmarkResult: (r) =>
|
| 145 |
set((s) => ({
|
|
|
|
| 8 |
PromptSnapshot,
|
| 9 |
Difficulty,
|
| 10 |
} from '../lib/types'
|
| 11 |
+
import { fetchBenchmarkQuestions } from '../lib/api'
|
| 12 |
|
| 13 |
interface Store {
|
| 14 |
// Theme
|
|
|
|
| 65 |
setPromptData: (data: { prompt: string; generation: number; history: PromptSnapshot[] }) => void
|
| 66 |
}
|
| 67 |
|
| 68 |
+
function makePending(id: string, question: string, difficulty: Difficulty): BenchmarkResult {
|
| 69 |
+
return { id, question, difficulty, status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null }
|
| 70 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
const PLACEHOLDER_QUERIES: BenchmarkResult[] = [
|
| 73 |
+
makePending('loading', 'Loading questionsβ¦', 'easy'),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
]
|
| 75 |
|
| 76 |
export const useStore = create<Store>((set) => ({
|
|
|
|
| 90 |
setTaskId: (id) => set({ taskId: id }),
|
| 91 |
setTaskDifficulty: (d) => {
|
| 92 |
const taskId = d === 'easy' ? 'simple_queries' : d === 'medium' ? 'join_queries' : 'complex_queries'
|
| 93 |
+
set({ taskDifficulty: d, taskId, overallScore: null })
|
| 94 |
+
fetchBenchmarkQuestions(d)
|
| 95 |
+
.then(({ questions }) => {
|
| 96 |
+
set({
|
| 97 |
+
benchmarkResults: questions.map((q) =>
|
| 98 |
+
makePending(q.id, q.question, q.difficulty as Difficulty)
|
| 99 |
+
),
|
| 100 |
+
})
|
| 101 |
+
})
|
| 102 |
+
.catch(() => { /* keep current list on error */ })
|
| 103 |
},
|
| 104 |
// DB
|
| 105 |
dbLabel: 'benchmark (built-in)',
|
|
|
|
| 127 |
setOptimizingBanner: (v) => set({ optimizingBanner: v }),
|
| 128 |
|
| 129 |
// Benchmark
|
| 130 |
+
benchmarkResults: PLACEHOLDER_QUERIES,
|
| 131 |
setBenchmarkResults: (r) => set({ benchmarkResults: r }),
|
| 132 |
updateBenchmarkResult: (r) =>
|
| 133 |
set((s) => ({
|