ar9avg committed on
Commit
8ae8e0b
·
1 Parent(s): f0b682f

Fix chat query failures and benchmark ID mismatches

Browse files

- Free-form chat: use error-based success check instead of task grader
(grader compared against wrong expected results, causing all queries to fail)
- Add /api/benchmark-questions endpoint to expose real task question IDs
- Benchmark SSE: rename query_id→id, success→pass, overall_score→overallScore
to match frontend field names
- Add queryIds filtering support to BenchmarkRequest
- Frontend: load benchmark questions from API instead of hardcoded IDs
(E1-E5 didn't match backend sq-01–sq-05)
- Remove duplicate difficulty tabs from BenchmarkPanel

backend/api/demo.py CHANGED
@@ -179,11 +179,9 @@ async def execute_query_stream(req: ExecuteQueryRequest):
179
 
180
  rows, error = execute_query(generated_sql)
181
 
182
- from env.tasks import grade_response
183
- task_score = grade_response(
184
- task_id, question_obj.id, generated_sql, rows, error, attempt
185
- )
186
- attempt_success = task_score >= 0.8
187
 
188
  current_error_class = None
189
  error_class_name = None
@@ -320,10 +318,30 @@ async def execute_query_stream(req: ExecuteQueryRequest):
320
  return EventSourceResponse(event_generator())
321
 
322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  # ─── /api/benchmark ───────────────────────────────────────────────
324
 
325
  class BenchmarkRequest(BaseModel):
326
  task_id: str = "simple_queries"
 
327
 
328
 
329
  @router.post("/benchmark")
@@ -333,10 +351,14 @@ async def run_benchmark(req: BenchmarkRequest):
333
  task = get_task(task_id)
334
  scores: list[float] = []
335
 
336
- for question_obj in task.questions:
 
 
 
 
337
  yield {"data": json.dumps({
338
  "type": "query_start",
339
- "query_id": question_obj.id,
340
  "question": question_obj.question,
341
  })}
342
 
@@ -444,17 +466,18 @@ async def run_benchmark(req: BenchmarkRequest):
444
 
445
  yield {"data": json.dumps({
446
  "type": "query_result",
447
- "query_id": question_obj.id,
448
- "success": success,
449
  "score": task_score,
450
  "sql": sql,
451
  "attempts": attempt,
 
452
  })}
453
 
454
  overall_score = sum(scores) / len(scores) if scores else 0.0
455
  yield {"data": json.dumps({
456
  "type": "done",
457
- "overall_score": overall_score,
458
  "task_id": task_id,
459
  })}
460
 
 
179
 
180
  rows, error = execute_query(generated_sql)
181
 
182
+ # For free-form chat, success = no SQL error (not task grader)
183
+ attempt_success = (error is None)
184
+ task_score = 1.0 if attempt_success else 0.0
 
 
185
 
186
  current_error_class = None
187
  error_class_name = None
 
318
  return EventSourceResponse(event_generator())
319
 
320
 
321
+ # ─── /api/benchmark-questions ────────────────────────────────────
322
+
323
+ @router.get("/benchmark-questions")
324
+ async def get_benchmark_questions(task_id: str = "easy"):
325
+ mapped_id = _DIFFICULTY_MAP.get(task_id, task_id)
326
+ task = get_task(mapped_id)
327
+ difficulty_label = task.difficulty # "easy" | "medium" | "hard"
328
+ return {
329
+ "questions": [
330
+ {
331
+ "id": q.id,
332
+ "question": q.question,
333
+ "difficulty": difficulty_label,
334
+ }
335
+ for q in task.questions
336
+ ]
337
+ }
338
+
339
+
340
  # ─── /api/benchmark ───────────────────────────────────────────────
341
 
342
  class BenchmarkRequest(BaseModel):
343
  task_id: str = "simple_queries"
344
+ queryIds: Optional[list[str]] = None
345
 
346
 
347
  @router.post("/benchmark")
 
351
  task = get_task(task_id)
352
  scores: list[float] = []
353
 
354
+ questions = task.questions
355
+ if req.queryIds:
356
+ questions = [q for q in questions if q.id in req.queryIds]
357
+
358
+ for question_obj in questions:
359
  yield {"data": json.dumps({
360
  "type": "query_start",
361
+ "id": question_obj.id,
362
  "question": question_obj.question,
363
  })}
364
 
 
466
 
467
  yield {"data": json.dumps({
468
  "type": "query_result",
469
+ "id": question_obj.id,
470
+ "pass": success,
471
  "score": task_score,
472
  "sql": sql,
473
  "attempts": attempt,
474
+ "reason": "Correct" if success else "Agent exhausted all repair attempts",
475
  })}
476
 
477
  overall_score = sum(scores) / len(scores) if scores else 0.0
478
  yield {"data": json.dumps({
479
  "type": "done",
480
+ "overallScore": overall_score,
481
  "task_id": task_id,
482
  })}
483
 
frontend/src/App.tsx CHANGED
@@ -11,7 +11,7 @@ import { RightSidebar } from './components/RightSidebar'
11
  import { DemoMode } from './components/DemoMode'
12
  import { ConnectDB } from './components/ConnectDB'
13
  import { useStore } from './store/useStore'
14
- import { fetchInit } from './lib/api'
15
 
16
  type Tab = 'chat' | 'benchmark' | 'er'
17
 
@@ -28,7 +28,7 @@ export default function App() {
28
  const [demoOpen, setDemoOpen] = useState(false)
29
  const [connectDbOpen, setConnectDbOpen] = useState(false)
30
 
31
- const { theme, setDbSeeded, setTables, setSchemaGraph, setDbLabel } = useStore()
32
 
33
  // Apply theme on mount / change
34
  useEffect(() => {
@@ -62,6 +62,30 @@ export default function App() {
62
  .catch(() => { /* noop */ })
63
  }, [setDbSeeded, setTables, setSchemaGraph])
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  // Close mobile sidebars on tab change
66
  useEffect(() => {
67
  setLeftOpen(false)
 
11
  import { DemoMode } from './components/DemoMode'
12
  import { ConnectDB } from './components/ConnectDB'
13
  import { useStore } from './store/useStore'
14
+ import { fetchInit, fetchBenchmarkQuestions } from './lib/api'
15
 
16
  type Tab = 'chat' | 'benchmark' | 'er'
17
 
 
28
  const [demoOpen, setDemoOpen] = useState(false)
29
  const [connectDbOpen, setConnectDbOpen] = useState(false)
30
 
31
+ const { theme, setDbSeeded, setTables, setSchemaGraph, setDbLabel, taskDifficulty } = useStore()
32
 
33
  // Apply theme on mount / change
34
  useEffect(() => {
 
62
  .catch(() => { /* noop */ })
63
  }, [setDbSeeded, setTables, setSchemaGraph])
64
 
65
+ // Load benchmark questions from API on mount
66
+ useEffect(() => {
67
+ const { setBenchmarkResults } = useStore.getState()
68
+ fetchBenchmarkQuestions(taskDifficulty)
69
+ .then(({ questions }) => {
70
+ setBenchmarkResults(
71
+ questions.map((q) => ({
72
+ id: q.id,
73
+ question: q.question,
74
+ difficulty: q.difficulty as 'easy' | 'medium' | 'hard',
75
+ status: 'pending' as const,
76
+ score: null,
77
+ sql: null,
78
+ reason: null,
79
+ attempts: null,
80
+ refRowCount: null,
81
+ agentRowCount: null,
82
+ }))
83
+ )
84
+ })
85
+ .catch(() => { /* noop */ })
86
+ // eslint-disable-next-line react-hooks/exhaustive-deps
87
+ }, [])
88
+
89
  // Close mobile sidebars on tab change
90
  useEffect(() => {
91
  setLeftOpen(false)
frontend/src/components/BenchmarkPanel.tsx CHANGED
@@ -356,23 +356,6 @@ export function BenchmarkPanel() {
356
  )}
357
  </div>
358
 
359
- {/* Difficulty tabs */}
360
- <div className="flex items-center gap-1 px-4 py-2 border-b border-white/[0.06] shrink-0">
361
- {DIFFICULTY_TABS.map((tab) => (
362
- <button
363
- key={tab.id}
364
- onClick={() => setTaskDifficulty(tab.id)}
365
- className={`px-3 py-1 rounded-lg text-xs font-medium transition-all ${
366
- taskDifficulty === tab.id
367
- ? 'bg-violet-600/20 text-violet-300 border border-violet-500/30'
368
- : 'text-gray-500 hover:text-gray-300 hover:bg-white/5 border border-transparent'
369
- }`}
370
- >
371
- {tab.label}
372
- </button>
373
- ))}
374
- </div>
375
-
376
  {/* Query list */}
377
  <div className="flex-1 overflow-y-auto">
378
  <div className="p-2 flex flex-col gap-1">
 
356
  )}
357
  </div>
358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  {/* Query list */}
360
  <div className="flex-1 overflow-y-auto">
361
  <div className="p-2 flex flex-col gap-1">
frontend/src/lib/api.ts CHANGED
@@ -96,6 +96,14 @@ export async function fetchPromptHistory() {
96
  return res.json()
97
  }
98
 
 
 
 
 
 
 
 
 
99
  export async function connectExternalDb(path: string): Promise<{ success: boolean; message: string; tables: { name: string; rows: number }[]; dbLabel: string }> {
100
  const res = await fetch(`${BASE_URL}/api/connect-db`, {
101
  method: 'POST',
 
96
  return res.json()
97
  }
98
 
99
+ export async function fetchBenchmarkQuestions(
100
+ taskId: string
101
+ ): Promise<{ questions: { id: string; question: string; difficulty: string }[] }> {
102
+ const res = await fetch(`${BASE_URL}/api/benchmark-questions?task_id=${encodeURIComponent(taskId)}`)
103
+ if (!res.ok) throw new Error(`HTTP ${res.status}`)
104
+ return res.json()
105
+ }
106
+
107
  export async function connectExternalDb(path: string): Promise<{ success: boolean; message: string; tables: { name: string; rows: number }[]; dbLabel: string }> {
108
  const res = await fetch(`${BASE_URL}/api/connect-db`, {
109
  method: 'POST',
frontend/src/store/useStore.ts CHANGED
@@ -8,6 +8,7 @@ import type {
8
  PromptSnapshot,
9
  Difficulty,
10
  } from '../lib/types'
 
11
 
12
  interface Store {
13
  // Theme
@@ -64,28 +65,12 @@ interface Store {
64
  setPromptData: (data: { prompt: string; generation: number; history: PromptSnapshot[] }) => void
65
  }
66
 
67
- const EASY_QUERIES: BenchmarkResult[] = [
68
- { id: 'E1', question: 'Show all products', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
69
- { id: 'E2', question: 'List all users from the USA', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
70
- { id: 'E3', question: 'What product categories exist?', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
71
- { id: 'E4', question: 'How many orders are in the database?', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
72
- { id: 'E5', question: 'Show all sellers with their names', difficulty: 'easy', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
73
- ]
74
-
75
- const MEDIUM_QUERIES: BenchmarkResult[] = [
76
- { id: 'M1', question: 'Top 5 sellers by total revenue', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
77
- { id: 'M2', question: 'Average order value by country', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
78
- { id: 'M3', question: 'Products with stock below 10 units', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
79
- { id: 'M4', question: 'Monthly order count for the last 12 months', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
80
- { id: 'M5', question: 'Categories ranked by number of products', difficulty: 'medium', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
81
- ]
82
 
83
- const HARD_QUERIES: BenchmarkResult[] = [
84
- { id: 'H1', question: 'Rolling 7-day revenue for the past 30 days', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
85
- { id: 'H2', question: 'Seller ranking with rank change from previous month', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
86
- { id: 'H3', question: 'Cohort retention analysis by signup month', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
87
- { id: 'H4', question: 'Identify top products contributing to 80% of revenue (Pareto)', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
88
- { id: 'H5', question: 'Customer lifetime value segmented by acquisition channel', difficulty: 'hard', status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null },
89
  ]
90
 
91
  export const useStore = create<Store>((set) => ({
@@ -105,13 +90,16 @@ export const useStore = create<Store>((set) => ({
105
  setTaskId: (id) => set({ taskId: id }),
106
  setTaskDifficulty: (d) => {
107
  const taskId = d === 'easy' ? 'simple_queries' : d === 'medium' ? 'join_queries' : 'complex_queries'
108
- set({
109
- taskDifficulty: d,
110
- taskId,
111
- benchmarkResults:
112
- d === 'easy' ? EASY_QUERIES : d === 'medium' ? MEDIUM_QUERIES : HARD_QUERIES,
113
- overallScore: null,
114
- })
 
 
 
115
  },
116
  // DB
117
  dbLabel: 'benchmark (built-in)',
@@ -139,7 +127,7 @@ export const useStore = create<Store>((set) => ({
139
  setOptimizingBanner: (v) => set({ optimizingBanner: v }),
140
 
141
  // Benchmark
142
- benchmarkResults: EASY_QUERIES,
143
  setBenchmarkResults: (r) => set({ benchmarkResults: r }),
144
  updateBenchmarkResult: (r) =>
145
  set((s) => ({
 
8
  PromptSnapshot,
9
  Difficulty,
10
  } from '../lib/types'
11
+ import { fetchBenchmarkQuestions } from '../lib/api'
12
 
13
  interface Store {
14
  // Theme
 
65
  setPromptData: (data: { prompt: string; generation: number; history: PromptSnapshot[] }) => void
66
  }
67
 
68
+ function makePending(id: string, question: string, difficulty: Difficulty): BenchmarkResult {
69
+ return { id, question, difficulty, status: 'pending', score: null, sql: null, reason: null, attempts: null, refRowCount: null, agentRowCount: null }
70
+ }
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ const PLACEHOLDER_QUERIES: BenchmarkResult[] = [
73
+ makePending('loading', 'Loading questions…', 'easy'),
 
 
 
 
74
  ]
75
 
76
  export const useStore = create<Store>((set) => ({
 
90
  setTaskId: (id) => set({ taskId: id }),
91
  setTaskDifficulty: (d) => {
92
  const taskId = d === 'easy' ? 'simple_queries' : d === 'medium' ? 'join_queries' : 'complex_queries'
93
+ set({ taskDifficulty: d, taskId, overallScore: null })
94
+ fetchBenchmarkQuestions(d)
95
+ .then(({ questions }) => {
96
+ set({
97
+ benchmarkResults: questions.map((q) =>
98
+ makePending(q.id, q.question, q.difficulty as Difficulty)
99
+ ),
100
+ })
101
+ })
102
+ .catch(() => { /* keep current list on error */ })
103
  },
104
  // DB
105
  dbLabel: 'benchmark (built-in)',
 
127
  setOptimizingBanner: (v) => set({ optimizingBanner: v }),
128
 
129
  // Benchmark
130
+ benchmarkResults: PLACEHOLDER_QUERIES,
131
  setBenchmarkResults: (r) => set({ benchmarkResults: r }),
132
  updateBenchmarkResult: (r) =>
133
  set((s) => ({