shwetangisingh committed on
Commit
c09a7e7
·
1 Parent(s): df78c68

Stop blocking /chat on evals; let the UI poll for them

Browse files
backend/api/main.py CHANGED
@@ -4,10 +4,12 @@ from __future__ import annotations
4
  import json
5
  import logging
6
  import re
 
7
  import time
 
8
  from pathlib import Path
9
 
10
- from fastapi import FastAPI, HTTPException
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from fastapi.responses import StreamingResponse
13
  from pydantic import BaseModel, Field
@@ -65,6 +67,33 @@ def _warmup():
65
  # ── In-memory session store (replace with Redis for multi-worker deployments) ──
66
  _sessions: dict[str, dict] = {}
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  # ── Request / response schemas ─────────────────────────────────────────────────
70
 
@@ -324,8 +353,11 @@ def _compute_and_persist_evals(
324
  )
325
  except Exception:
326
  _log.exception("evals scoring failed for run %s", run_id)
 
327
  return None
328
 
 
 
329
  try:
330
  entry = {
331
  "run_id": run_id,
@@ -345,7 +377,7 @@ def _compute_and_persist_evals(
345
 
346
 
347
  @app.post("/chat", response_model=ChatResponse)
348
- def chat(req: ChatRequest):
349
  guard = check_input(req.query)
350
  if not guard["allowed"]:
351
  return ChatResponse(
@@ -375,17 +407,21 @@ def chat(req: ChatRequest):
375
  affect_emotion = (result.get("affect") or {}).get("emotion", "NEUTRAL")
376
  run_id = result.get("run_id")
377
 
378
- eval_scores = _compute_and_persist_evals(
379
- run_id=run_id,
380
- user_id=req.user_id,
381
- turn_id=result["turn_id"],
382
- response=result["selected_response"] or "",
383
- chunks=list(result.get("retrieved_chunks") or []),
384
- latency_log=dict(result.get("latency_log") or {}),
385
- affect=affect_emotion,
386
- gesture_tag=req.gesture_tag,
387
- gaze_bucket=req.gaze_bucket,
388
- )
 
 
 
 
389
 
390
  return ChatResponse(
391
  user_id=req.user_id,
@@ -400,7 +436,7 @@ def chat(req: ChatRequest):
400
  guardrail_passed=result.get("guardrail_passed", True),
401
  run_id=run_id,
402
  turn_id=result["turn_id"],
403
- eval_scores=eval_scores,
404
  )
405
 
406
 
@@ -412,7 +448,6 @@ def chat_stream(req: ChatRequest):
412
  """
413
  guard = check_input(req.query)
414
  if not guard["allowed"]:
415
- # Mirror the non-stream /chat early-exit.
416
  payload = {
417
  "user_id": req.user_id,
418
  "query": req.query,
@@ -463,17 +498,24 @@ def chat_stream(req: ChatRequest):
463
  affect_emotion = (state.get("affect") or {}).get("emotion", "NEUTRAL")
464
  run_id = state.get("run_id")
465
 
466
- eval_scores = _compute_and_persist_evals(
467
- run_id=run_id,
468
- user_id=req.user_id,
469
- turn_id=state["turn_id"],
470
- response=state["selected_response"] or "",
471
- chunks=list(state.get("retrieved_chunks") or []),
472
- latency_log=dict(state.get("latency_log") or {}),
473
- affect=affect_emotion,
474
- gesture_tag=req.gesture_tag,
475
- gaze_bucket=req.gaze_bucket,
476
- )
 
 
 
 
 
 
 
477
 
478
  final = {
479
  "user_id": req.user_id,
@@ -488,7 +530,7 @@ def chat_stream(req: ChatRequest):
488
  "guardrail_passed": state.get("guardrail_passed", True),
489
  "run_id": run_id,
490
  "turn_id": state["turn_id"],
491
- "eval_scores": eval_scores,
492
  }
493
  yield _sse({"type": "complete", "response": final})
494
 
@@ -503,8 +545,23 @@ def _sse(data: dict) -> str:
503
  return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
504
 
505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  @app.post("/chat/turnaround", response_model=ChatResponse)
507
- def chat_turnaround(req: TurnaroundRequest):
508
  if req.user_id not in _sessions:
509
  raise HTTPException(status_code=404, detail="no active session")
510
 
@@ -577,17 +634,20 @@ def chat_turnaround(req: TurnaroundRequest):
577
  affect_emotion = (replan_state.get("affect") or {}).get("emotion", "NEUTRAL")
578
  run_id = replan_state.get("run_id")
579
 
580
- eval_scores = _compute_and_persist_evals(
581
- run_id=run_id,
582
- user_id=req.user_id,
583
- turn_id=replan_state["turn_id"],
584
- response=replan_state["selected_response"] or "",
585
- chunks=list(replan_state.get("retrieved_chunks") or []),
586
- latency_log=dict(replan_state.get("latency_log") or {}),
587
- affect=affect_emotion,
588
- gesture_tag=replan_state.get("gesture_tag"),
589
- gaze_bucket=replan_state.get("gaze_bucket"),
590
- )
 
 
 
591
 
592
  return ChatResponse(
593
  user_id=req.user_id,
@@ -884,7 +944,7 @@ def chat_regenerate(req: RegenerateRequest):
884
  guardrail_passed=replan_state.get("guardrail_passed", True),
885
  run_id=run_id,
886
  turn_id=replan_state["turn_id"],
887
- eval_scores=eval_scores,
888
  )
889
 
890
 
 
4
  import json
5
  import logging
6
  import re
7
+ import threading
8
  import time
9
+ from collections import OrderedDict
10
  from pathlib import Path
11
 
12
+ from fastapi import BackgroundTasks, FastAPI, HTTPException
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from fastapi.responses import StreamingResponse
15
  from pydantic import BaseModel, Field
 
67
  # ── In-memory session store (replace with Redis for multi-worker deployments) ──
68
  _sessions: dict[str, dict] = {}
69
 
70
# Eval scores keyed by run_id, filled by a BackgroundTask after /chat returns
# so the UI can render the response immediately and poll GET /evals/{run_id}.
# Multi-worker deploys should swap this (and _sessions) for Redis.
#
# Entry states, as interpreted by GET /evals/{run_id}:
#   key missing      -> "unknown"
#   {} (empty dict)  -> "pending"  (slot reserved, scoring still in flight)
#   _EVAL_FAILED     -> "failed"   (sentinel, compared by identity)
#   non-empty dict   -> "ready"
_EVAL_FAILED: dict = {"_failed": True}
_eval_results: OrderedDict[str, dict] = OrderedDict()
_eval_lock = threading.Lock()
_EVAL_RESULTS_MAX = 200


def _evict_oldest_evals_locked() -> None:
    """Drop the oldest entries until the store holds at most _EVAL_RESULTS_MAX.

    The caller must already hold ``_eval_lock``.
    """
    while len(_eval_results) > _EVAL_RESULTS_MAX:
        _eval_results.popitem(last=False)


def _remember_eval(run_id: str, scores: dict | None) -> None:
    """Record the final eval outcome for ``run_id``.

    Args:
        run_id: Identifier of the run the scores belong to. A falsy value
            (``None`` or ``""``) is ignored, because such a key could never be
            looked up and would only occupy a slot in the bounded store.
        scores: The computed scores. ``None`` or an empty dict is stored as
            the ``_EVAL_FAILED`` sentinel, since an empty dict is reserved to
            mean "pending".
    """
    if not run_id:
        return
    value = scores if scores else _EVAL_FAILED
    with _eval_lock:
        _eval_results[run_id] = value
        # Assigning to an existing key keeps its old position, so refresh it
        # explicitly to make the newest result the last one to be evicted.
        _eval_results.move_to_end(run_id)
        _evict_oldest_evals_locked()


def _reserve_eval_slot(run_id: str) -> None:
    """Mark a run_id as in-flight so /evals can report 'pending' vs 'unknown'.

    An entry that already exists (pending, ready, or failed) is kept as is and
    only moved to the most-recent position. A falsy ``run_id`` is ignored.
    """
    if not run_id:
        return
    with _eval_lock:
        _eval_results.setdefault(run_id, {})  # empty dict = pending
        _eval_results.move_to_end(run_id)
        _evict_oldest_evals_locked()
96
+
97
 
98
  # ── Request / response schemas ─────────────────────────────────────────────────
99
 
 
353
  )
354
  except Exception:
355
  _log.exception("evals scoring failed for run %s", run_id)
356
+ _remember_eval(run_id, None)
357
  return None
358
 
359
+ _remember_eval(run_id, scores)
360
+
361
  try:
362
  entry = {
363
  "run_id": run_id,
 
377
 
378
 
379
  @app.post("/chat", response_model=ChatResponse)
380
+ def chat(req: ChatRequest, background_tasks: BackgroundTasks):
381
  guard = check_input(req.query)
382
  if not guard["allowed"]:
383
  return ChatResponse(
 
407
  affect_emotion = (result.get("affect") or {}).get("emotion", "NEUTRAL")
408
  run_id = result.get("run_id")
409
 
410
+ # Evals (NLI cross-encoder) run off the response path; UI polls /evals.
411
+ if run_id and settings.evals_enabled:
412
+ _reserve_eval_slot(run_id)
413
+ background_tasks.add_task(
414
+ _compute_and_persist_evals,
415
+ run_id=run_id,
416
+ user_id=req.user_id,
417
+ turn_id=result["turn_id"],
418
+ response=result["selected_response"] or "",
419
+ chunks=list(result.get("retrieved_chunks") or []),
420
+ latency_log=dict(result.get("latency_log") or {}),
421
+ affect=affect_emotion,
422
+ gesture_tag=req.gesture_tag,
423
+ gaze_bucket=req.gaze_bucket,
424
+ )
425
 
426
  return ChatResponse(
427
  user_id=req.user_id,
 
436
  guardrail_passed=result.get("guardrail_passed", True),
437
  run_id=run_id,
438
  turn_id=result["turn_id"],
439
+ eval_scores=None,
440
  )
441
 
442
 
 
448
  """
449
  guard = check_input(req.query)
450
  if not guard["allowed"]:
 
451
  payload = {
452
  "user_id": req.user_id,
453
  "query": req.query,
 
498
  affect_emotion = (state.get("affect") or {}).get("emotion", "NEUTRAL")
499
  run_id = state.get("run_id")
500
 
501
+ # Evals run off the response path; UI polls GET /evals/{run_id}.
502
+ if run_id and settings.evals_enabled:
503
+ _reserve_eval_slot(run_id)
504
+ threading.Thread(
505
+ target=_compute_and_persist_evals,
506
+ kwargs=dict(
507
+ run_id=run_id,
508
+ user_id=req.user_id,
509
+ turn_id=state["turn_id"],
510
+ response=state["selected_response"] or "",
511
+ chunks=list(state.get("retrieved_chunks") or []),
512
+ latency_log=dict(state.get("latency_log") or {}),
513
+ affect=affect_emotion,
514
+ gesture_tag=req.gesture_tag,
515
+ gaze_bucket=req.gaze_bucket,
516
+ ),
517
+ daemon=True,
518
+ ).start()
519
 
520
  final = {
521
  "user_id": req.user_id,
 
530
  "guardrail_passed": state.get("guardrail_passed", True),
531
  "run_id": run_id,
532
  "turn_id": state["turn_id"],
533
+ "eval_scores": None,
534
  }
535
  yield _sse({"type": "complete", "response": final})
536
 
 
545
  return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
546
 
547
 
548
@app.get("/evals/{run_id}")
def get_evals(run_id: str):
    """Report the eval status for a run so the UI can poll for scores.

    Responds with ``{"status", "run_id", "eval_scores"}`` where status is:
      - "unknown": no entry is stored for this run_id
      - "failed":  the stored entry is the _EVAL_FAILED sentinel
      - "pending": the stored entry is empty (slot reserved, not yet scored)
      - "ready":   scores are available and returned in ``eval_scores``

    Raises:
        HTTPException: 400 when ``run_id`` does not match ``_RUN_ID_RE``.
    """
    if _RUN_ID_RE.match(run_id) is None:
        raise HTTPException(status_code=400, detail="invalid run_id")

    # Hold the lock only for the lookup; the response is built outside it.
    with _eval_lock:
        stored = _eval_results.get(run_id)

    scores = None
    if stored is None:
        status = "unknown"
    elif stored is _EVAL_FAILED:
        # Identity check: the sentinel is itself a non-empty dict.
        status = "failed"
    elif not stored:
        status = "pending"
    else:
        status = "ready"
        scores = stored
    return {"status": status, "run_id": run_id, "eval_scores": scores}
561
+
562
+
563
  @app.post("/chat/turnaround", response_model=ChatResponse)
564
+ def chat_turnaround(req: TurnaroundRequest, background_tasks: BackgroundTasks):
565
  if req.user_id not in _sessions:
566
  raise HTTPException(status_code=404, detail="no active session")
567
 
 
634
  affect_emotion = (replan_state.get("affect") or {}).get("emotion", "NEUTRAL")
635
  run_id = replan_state.get("run_id")
636
 
637
+ if run_id and settings.evals_enabled:
638
+ _reserve_eval_slot(run_id)
639
+ background_tasks.add_task(
640
+ _compute_and_persist_evals,
641
+ run_id=run_id,
642
+ user_id=req.user_id,
643
+ turn_id=replan_state["turn_id"],
644
+ response=replan_state["selected_response"] or "",
645
+ chunks=list(replan_state.get("retrieved_chunks") or []),
646
+ latency_log=dict(replan_state.get("latency_log") or {}),
647
+ affect=affect_emotion,
648
+ gesture_tag=replan_state.get("gesture_tag"),
649
+ gaze_bucket=replan_state.get("gaze_bucket"),
650
+ )
651
 
652
  return ChatResponse(
653
  user_id=req.user_id,
 
944
  guardrail_passed=replan_state.get("guardrail_passed", True),
945
  run_id=run_id,
946
  turn_id=replan_state["turn_id"],
947
+ eval_scores=None,
948
  )
949
 
950
 
frontend/src/components/ChatPanel.tsx CHANGED
@@ -7,6 +7,7 @@ import type {
7
  SensingState,
8
  } from "../types";
9
  import {
 
10
  sendPick,
11
  sendTurnaround,
12
  streamChat,
@@ -141,6 +142,7 @@ export function ChatPanel({
141
  // against the new turnaround bubble's own head-signal re-firing turnaround
142
  // on itself.
143
  const turnaroundConsumedTurnRef = useRef<number | null>(null);
 
144
 
145
  useEffect(() => {
146
  bottomRef.current?.scrollIntoView({ behavior: "smooth" });
@@ -152,8 +154,39 @@ export function ChatPanel({
152
  lastTurnIdRef.current = null;
153
  turnaroundConsumedTurnRef.current = null;
154
  lastResponseTsRef.current = 0;
 
 
155
  }, [userId]);
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  const handleTurnaround = useCallback(
158
  async (reason: "head" | "manual") => {
159
  if (!userId || !backendReady || turnaroundLoading || loading) return;
@@ -188,7 +221,7 @@ export function ChatPanel({
188
  affect: res.affect,
189
  runId: res.run_id,
190
  turnId: res.turn_id,
191
- evalScores: res.eval_scores ?? null,
192
  isTurnaround: true,
193
  candidates: res.candidates ?? [],
194
  picked: true,
@@ -196,6 +229,7 @@ export function ChatPanel({
196
  return next;
197
  });
198
  onLatency(res.latency);
 
199
  // Do NOT advance lastResponseTsRef — keep the original turn's window so
200
  // the user can't head-shake the turnaround itself into another loop.
201
  } catch (e) {
@@ -223,6 +257,7 @@ export function ChatPanel({
223
  setMessages,
224
  onLatency,
225
  onHeadSignalConsumed,
 
226
  ]
227
  );
228
 
@@ -312,11 +347,12 @@ export function ChatPanel({
312
  affect: res.affect,
313
  runId: res.run_id,
314
  turnId: res.turn_id,
315
- evalScores: res.eval_scores ?? null,
316
  candidates: res.candidates ?? m.candidates ?? [],
317
  picked: false,
318
  }));
319
  onLatency(res.latency);
 
320
  }
321
  },
322
  );
@@ -337,6 +373,7 @@ export function ChatPanel({
337
  queueToken,
338
  flushNow,
339
  onLatency,
 
340
  ]
341
  );
342
 
@@ -476,12 +513,13 @@ export function ChatPanel({
476
  affect: res.affect,
477
  runId: res.run_id,
478
  turnId: res.turn_id,
479
- evalScores: res.eval_scores ?? null,
480
  candidates: res.candidates ?? m.candidates ?? [],
481
  picked: (res.candidates ?? []).length <= 1,
482
  }));
483
  onLatency(res.latency);
484
  lastResponseTsRef.current = performance.now();
 
485
  }
486
  },
487
  );
 
7
  SensingState,
8
  } from "../types";
9
  import {
10
+ pollEvals,
11
  sendPick,
12
  sendTurnaround,
13
  streamChat,
 
142
  // against the new turnaround bubble's own head-signal re-firing turnaround
143
  // on itself.
144
  const turnaroundConsumedTurnRef = useRef<number | null>(null);
145
+ const evalPollAbortsRef = useRef<Set<AbortController>>(new Set());
146
 
147
  useEffect(() => {
148
  bottomRef.current?.scrollIntoView({ behavior: "smooth" });
 
154
  lastTurnIdRef.current = null;
155
  turnaroundConsumedTurnRef.current = null;
156
  lastResponseTsRef.current = 0;
157
+ evalPollAbortsRef.current.forEach((ac) => ac.abort());
158
+ evalPollAbortsRef.current.clear();
159
  }, [userId]);
160
 
161
+ useEffect(() => {
162
+ const active = evalPollAbortsRef.current;
163
+ return () => {
164
+ active.forEach((ac) => ac.abort());
165
+ active.clear();
166
+ };
167
+ }, []);
168
+
169
// Kick off background polling for a run's eval scores and, once they arrive,
// attach them to every message carrying that runId. Each poll registers an
// AbortController in evalPollAbortsRef so it can be cancelled from outside.
const startEvalPolling = useCallback(
  (runId: string | null | undefined) => {
    if (!runId) return;

    const controller = new AbortController();
    evalPollAbortsRef.current.add(controller);

    const run = async () => {
      try {
        const scores = await pollEvals(runId, { signal: controller.signal });
        // Skip the state update if the poll was cancelled or produced nothing.
        if (controller.signal.aborted || !scores) return;
        setMessages((prev) =>
          prev.map((m) => {
            if (m.runId !== runId) return m;
            return { ...m, evalScores: scores };
          })
        );
      } finally {
        // Always deregister, whether the poll succeeded, failed, or aborted.
        evalPollAbortsRef.current.delete(controller);
      }
    };

    void run();
  },
  [setMessages]
);
189
+
190
  const handleTurnaround = useCallback(
191
  async (reason: "head" | "manual") => {
192
  if (!userId || !backendReady || turnaroundLoading || loading) return;
 
221
  affect: res.affect,
222
  runId: res.run_id,
223
  turnId: res.turn_id,
224
+ evalScores: null,
225
  isTurnaround: true,
226
  candidates: res.candidates ?? [],
227
  picked: true,
 
229
  return next;
230
  });
231
  onLatency(res.latency);
232
+ startEvalPolling(res.run_id);
233
  // Do NOT advance lastResponseTsRef — keep the original turn's window so
234
  // the user can't head-shake the turnaround itself into another loop.
235
  } catch (e) {
 
257
  setMessages,
258
  onLatency,
259
  onHeadSignalConsumed,
260
+ startEvalPolling,
261
  ]
262
  );
263
 
 
347
  affect: res.affect,
348
  runId: res.run_id,
349
  turnId: res.turn_id,
350
+ evalScores: null,
351
  candidates: res.candidates ?? m.candidates ?? [],
352
  picked: false,
353
  }));
354
  onLatency(res.latency);
355
+ startEvalPolling(res.run_id);
356
  }
357
  },
358
  );
 
373
  queueToken,
374
  flushNow,
375
  onLatency,
376
+ startEvalPolling,
377
  ]
378
  );
379
 
 
513
  affect: res.affect,
514
  runId: res.run_id,
515
  turnId: res.turn_id,
516
+ evalScores: null,
517
  candidates: res.candidates ?? m.candidates ?? [],
518
  picked: (res.candidates ?? []).length <= 1,
519
  }));
520
  onLatency(res.latency);
521
  lastResponseTsRef.current = performance.now();
522
+ startEvalPolling(res.run_id);
523
  }
524
  },
525
  );
frontend/src/lib/api.ts CHANGED
@@ -1,6 +1,7 @@
1
  import type {
2
  ChatRequest,
3
  ChatResponse,
 
4
  Persona,
5
  TurnaroundRequest,
6
  } from "../types";
@@ -140,6 +141,56 @@ export async function sendPick(args: {
140
  if (!res.ok) throw new Error(`API error: ${res.status}`);
141
  }
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  export async function submitRating(args: {
144
  run_id: string;
145
  user_id: string;
 
1
  import type {
2
  ChatRequest,
3
  ChatResponse,
4
+ EvalScores,
5
  Persona,
6
  TurnaroundRequest,
7
  } from "../types";
 
141
  if (!res.ok) throw new Error(`API error: ${res.status}`);
142
  }
143
 
144
+ export type EvalsStatus = "pending" | "ready" | "failed" | "unknown";
145
+ export interface EvalsFetchResult {
146
+ status: EvalsStatus;
147
+ run_id: string;
148
+ eval_scores: EvalScores | null;
149
+ }
150
+
151
/**
 * Fetch the current eval status for a run from GET /evals/{runId}.
 *
 * @param runId  Run identifier; URL-encoded before being placed in the path.
 * @param signal Optional AbortSignal forwarded to `fetch` so the request can
 *               be cancelled by the caller.
 * @throws Error when the response status is not OK, or when the request is
 *         aborted or fails at the network level (propagated from `fetch`).
 */
export async function fetchEvals(
  runId: string,
  signal?: AbortSignal
): Promise<EvalsFetchResult> {
  const res = await fetch(`${API_BASE}/evals/${encodeURIComponent(runId)}`, {
    signal,
  });
  if (!res.ok) throw new Error(`API error: ${res.status}`);
  return res.json();
}

/** Resolve after `ms` milliseconds, or as soon as `signal` aborts. */
function sleepUnlessAborted(ms: number, signal?: AbortSignal): Promise<void> {
  return new Promise((resolve) => {
    if (signal?.aborted) {
      resolve();
      return;
    }
    const timer = setTimeout(finish, ms);
    function finish() {
      // Runs on either timer expiry or abort; clean up both sides.
      signal?.removeEventListener("abort", finish);
      clearTimeout(timer);
      resolve();
    }
    signal?.addEventListener("abort", finish, { once: true });
  });
}

/**
 * Poll GET /evals/{runId} with exponential backoff until scores are ready.
 *
 * Resolves with the scores on "ready". Resolves with `null` on "failed",
 * after three consecutive "unknown" responses, on timeout, or when
 * `opts.signal` is aborted. Fetch errors are treated as transient and retried.
 *
 * @param opts.initialDelayMs First backoff delay (default 300).
 * @param opts.maxDelayMs     Upper bound for the doubled delay (default 2000).
 * @param opts.timeoutMs      Overall polling budget (default 20000).
 * @param opts.signal         Cancels the in-flight request and the backoff wait.
 */
export async function pollEvals(
  runId: string,
  opts: {
    initialDelayMs?: number;
    maxDelayMs?: number;
    timeoutMs?: number;
    signal?: AbortSignal;
  } = {}
): Promise<EvalScores | null> {
  const signal = opts.signal;
  const maxDelayMs = opts.maxDelayMs ?? 2000;
  const timeoutMs = opts.timeoutMs ?? 20000;
  let delay = opts.initialDelayMs ?? 300;
  const start = performance.now();
  // Track consecutive "unknown" responses so transient race conditions (poll
  // racing the server picking up the new run_id) don't immediately give up.
  let unknownStreak = 0;
  while (performance.now() - start < timeoutMs) {
    if (signal?.aborted) return null;
    try {
      const r = await fetchEvals(runId, signal);
      if (r.status === "ready") return r.eval_scores;
      if (r.status === "failed") return null;
      if (r.status === "unknown") {
        unknownStreak += 1;
        if (unknownStreak >= 3) return null;
      } else {
        unknownStreak = 0;
      }
    } catch (e) {
      // An aborted fetch rejects; that is a cancellation, not a transient error.
      if (signal?.aborted) return null;
      console.warn("pollEvals: transient error", e);
    }
    await sleepUnlessAborted(delay, signal);
    delay = Math.min(delay * 2, maxDelayMs);
  }
  return null;
}
193
+
194
  export async function submitRating(args: {
195
  run_id: string;
196
  user_id: string;
frontend/vite.config.ts CHANGED
@@ -11,6 +11,9 @@ export default defineConfig({
11
  "/users": "http://localhost:8000",
12
  "/session": "http://localhost:8000",
13
  "/health": "http://localhost:8000",
 
 
 
14
  },
15
  },
16
  })
 
11
  "/users": "http://localhost:8000",
12
  "/session": "http://localhost:8000",
13
  "/health": "http://localhost:8000",
14
+ "/evals": "http://localhost:8000",
15
+ "/feedback": "http://localhost:8000",
16
+ "/debug": "http://localhost:8000",
17
  },
18
  },
19
  })