akseljoonas HF Staff commited on
Commit
7d17616
·
1 Parent(s): ecbfd3c

feat: cooperative cancellation, session persistence, background generation

Browse files
backend/routes/agent.py CHANGED
@@ -298,6 +298,18 @@ async def interrupt_session(
298
  return {"status": "interrupted", "session_id": session_id}
299
 
300
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  @router.post("/undo/{session_id}")
302
  async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict:
303
  """Undo the last turn in a session."""
 
298
  return {"status": "interrupted", "session_id": session_id}
299
 
300
 
301
+ @router.get("/session/{session_id}/messages")
302
+ async def get_session_messages(
303
+ session_id: str, user: dict = Depends(get_current_user)
304
+ ) -> list[dict]:
305
+ """Return the session's message history from memory."""
306
+ _check_session_access(session_id, user)
307
+ agent_session = session_manager.sessions.get(session_id)
308
+ if not agent_session or not agent_session.is_active:
309
+ raise HTTPException(status_code=404, detail="Session not found or inactive")
310
+ return [msg.model_dump() for msg in agent_session.session.context_manager.items]
311
+
312
+
313
  @router.post("/undo/{session_id}")
314
  async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict:
315
  """Undo the last turn in a session."""
frontend/src/components/Layout/AppLayout.tsx CHANGED
@@ -110,7 +110,7 @@ export default function AppLayout() {
110
 
111
  const hasAnySessions = sessions.length > 0;
112
 
113
- const { messages, sendMessage, stop, undoLastTurn, approveTools } = useAgentChat({
114
  sessionId: activeSessionId,
115
  onReady: () => logger.log('Agent ready'),
116
  onError: (error) => logger.error('Agent error:', error),
@@ -209,7 +209,7 @@ export default function AppLayout() {
209
  },
210
  }}
211
  >
212
- <SessionSidebar onClose={handleSidebarClose} />
213
  </Drawer>
214
  );
215
 
 
110
 
111
  const hasAnySessions = sessions.length > 0;
112
 
113
+ const { messages, sendMessage, stop, undoLastTurn, approveTools, flushMessages } = useAgentChat({
114
  sessionId: activeSessionId,
115
  onReady: () => logger.log('Agent ready'),
116
  onError: (error) => logger.error('Agent error:', error),
 
209
  },
210
  }}
211
  >
212
+ <SessionSidebar onClose={handleSidebarClose} onBeforeSwitch={flushMessages} />
213
  </Drawer>
214
  );
215
 
frontend/src/components/SessionSidebar/SessionSidebar.tsx CHANGED
@@ -16,6 +16,7 @@ import { apiFetch } from '@/utils/api';
16
 
17
  interface SessionSidebarProps {
18
  onClose?: () => void;
 
19
  }
20
 
21
  /** Small coloured dot for connection status */
@@ -32,7 +33,7 @@ const StatusDot = ({ connected }: { connected: boolean }) => (
32
  />
33
  );
34
 
35
- export default function SessionSidebar({ onClose }: SessionSidebarProps) {
36
  const { sessions, activeSessionId, createSession, deleteSession, switchSession } =
37
  useSessionStore();
38
  const { isConnected, setPlan, clearPanel } =
@@ -44,6 +45,7 @@ export default function SessionSidebar({ onClose }: SessionSidebarProps) {
44
 
45
  const handleNewSession = useCallback(async () => {
46
  if (isCreatingSession) return;
 
47
  setIsCreatingSession(true);
48
  setCapacityError(null);
49
  try {
@@ -63,7 +65,7 @@ export default function SessionSidebar({ onClose }: SessionSidebarProps) {
63
  } finally {
64
  setIsCreatingSession(false);
65
  }
66
- }, [isCreatingSession, createSession, setPlan, clearPanel, onClose]);
67
 
68
  const handleDelete = useCallback(
69
  async (sessionId: string, e: React.MouseEvent) => {
@@ -81,12 +83,13 @@ export default function SessionSidebar({ onClose }: SessionSidebarProps) {
81
 
82
  const handleSelect = useCallback(
83
  (sessionId: string) => {
 
84
  switchSession(sessionId);
85
  setPlan([]);
86
  clearPanel();
87
  onClose?.();
88
  },
89
- [switchSession, setPlan, clearPanel, onClose],
90
  );
91
 
92
  const formatTime = (d: string) =>
@@ -270,7 +273,7 @@ export default function SessionSidebar({ onClose }: SessionSidebarProps) {
270
  )}
271
  </Box>
272
 
273
- {/* ── Footer: New Session + status ──────────────────────────── */}
274
  <Divider sx={{ opacity: 0.5 }} />
275
  <Box
276
  sx={{
@@ -319,7 +322,7 @@ export default function SessionSidebar({ onClose }: SessionSidebarProps) {
319
  ) : (
320
  <>
321
  <AddIcon sx={{ fontSize: 16 }} />
322
- New Session
323
  </>
324
  )}
325
  </Box>
 
16
 
17
  interface SessionSidebarProps {
18
  onClose?: () => void;
19
+ onBeforeSwitch?: () => void;
20
  }
21
 
22
  /** Small coloured dot for connection status */
 
33
  />
34
  );
35
 
36
+ export default function SessionSidebar({ onClose, onBeforeSwitch }: SessionSidebarProps) {
37
  const { sessions, activeSessionId, createSession, deleteSession, switchSession } =
38
  useSessionStore();
39
  const { isConnected, setPlan, clearPanel } =
 
45
 
46
  const handleNewSession = useCallback(async () => {
47
  if (isCreatingSession) return;
48
+ onBeforeSwitch?.();
49
  setIsCreatingSession(true);
50
  setCapacityError(null);
51
  try {
 
65
  } finally {
66
  setIsCreatingSession(false);
67
  }
68
+ }, [isCreatingSession, onBeforeSwitch, createSession, setPlan, clearPanel, onClose]);
69
 
70
  const handleDelete = useCallback(
71
  async (sessionId: string, e: React.MouseEvent) => {
 
83
 
84
  const handleSelect = useCallback(
85
  (sessionId: string) => {
86
+ onBeforeSwitch?.();
87
  switchSession(sessionId);
88
  setPlan([]);
89
  clearPanel();
90
  onClose?.();
91
  },
92
+ [onBeforeSwitch, switchSession, setPlan, clearPanel, onClose],
93
  );
94
 
95
  const formatTime = (d: string) =>
 
273
  )}
274
  </Box>
275
 
276
+ {/* ── Footer: New Task + status ──────────────────────────── */}
277
  <Divider sx={{ opacity: 0.5 }} />
278
  <Box
279
  sx={{
 
322
  ) : (
323
  <>
324
  <AddIcon sx={{ fontSize: 16 }} />
325
+ New Task
326
  </>
327
  )}
328
  </Box>
frontend/src/hooks/useAgentChat.ts CHANGED
@@ -8,6 +8,7 @@ import { useChat } from '@ai-sdk/react';
8
  import type { UIMessage } from 'ai';
9
  import { WebSocketChatTransport, type SideChannelCallbacks } from '@/lib/ws-chat-transport';
10
  import { loadMessages, saveMessages } from '@/lib/chat-message-store';
 
11
  import { apiFetch } from '@/utils/api';
12
  import { useAgentStore } from '@/store/agentStore';
13
  import { useSessionStore } from '@/store/sessionStore';
@@ -210,12 +211,6 @@ export function useAgentChat({ sessionId, onReady, onError, onSessionDead }: Use
210
  messages: initialMessages,
211
  transport: transportRef.current!,
212
  experimental_throttle: 80,
213
- onFinish: ({ messages, isAbort, isError }) => {
214
- if (isAbort || isError) return;
215
- if (sessionId && messages.length > 0) {
216
- saveMessages(sessionId, messages);
217
- }
218
- },
219
  onError: (error) => {
220
  logger.error('useChat error:', error);
221
  setError(error.message);
@@ -227,7 +222,30 @@ export function useAgentChat({ sessionId, onReady, onError, onSessionDead }: Use
227
  chatActionsRef.current.setMessages = chat.setMessages;
228
  chatActionsRef.current.messages = chat.messages;
229
 
230
- // ── Persist messages on every user send (onFinish covers assistant turns) ──
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  const prevLenRef = useRef(initialMessages.length);
232
  useEffect(() => {
233
  if (!sessionId || chat.messages.length === 0) return;
@@ -267,6 +285,12 @@ export function useAgentChat({ sessionId, onReady, onError, onSessionDead }: Use
267
  [sessionId, setProcessing],
268
  );
269
 
 
 
 
 
 
 
270
  return {
271
  messages: chat.messages,
272
  sendMessage: chat.sendMessage,
@@ -274,6 +298,7 @@ export function useAgentChat({ sessionId, onReady, onError, onSessionDead }: Use
274
  status: chat.status,
275
  undoLastTurn,
276
  approveTools,
 
277
  transport: transportRef.current,
278
  };
279
  }
 
8
  import type { UIMessage } from 'ai';
9
  import { WebSocketChatTransport, type SideChannelCallbacks } from '@/lib/ws-chat-transport';
10
  import { loadMessages, saveMessages } from '@/lib/chat-message-store';
11
+ import { llmMessagesToUIMessages } from '@/lib/convert-llm-messages';
12
  import { apiFetch } from '@/utils/api';
13
  import { useAgentStore } from '@/store/agentStore';
14
  import { useSessionStore } from '@/store/sessionStore';
 
211
  messages: initialMessages,
212
  transport: transportRef.current!,
213
  experimental_throttle: 80,
 
 
 
 
 
 
214
  onError: (error) => {
215
  logger.error('useChat error:', error);
216
  setError(error.message);
 
222
  chatActionsRef.current.setMessages = chat.setMessages;
223
  chatActionsRef.current.messages = chat.messages;
224
 
225
+ // ── Hydrate from backend when switching to a session ──────────────
226
+ useEffect(() => {
227
+ if (!sessionId) return;
228
+ let cancelled = false;
229
+ apiFetch(`/api/session/${sessionId}/messages`)
230
+ .then((res) => (res.ok ? res.json() : null))
231
+ .then((data) => {
232
+ if (cancelled || !data || !Array.isArray(data) || data.length === 0) return;
233
+ const uiMsgs = llmMessagesToUIMessages(data);
234
+ if (uiMsgs.length > 0) {
235
+ chat.setMessages(uiMsgs);
236
+ saveMessages(sessionId, uiMsgs);
237
+ }
238
+ })
239
+ .catch(() => { /* backend unreachable — localStorage fallback is fine */ });
240
+ return () => { cancelled = true; };
241
+ }, [sessionId]); // eslint-disable-line react-hooks/exhaustive-deps
242
+
243
+ // ── Persist messages ──────────────────────────────────────────────
244
+ const flushRef = useRef<{ sid: string | null; msgs: UIMessage[] }>({ sid: null, msgs: [] });
245
+ flushRef.current.sid = sessionId;
246
+ flushRef.current.msgs = chat.messages;
247
+
248
+ // Save whenever message count changes (covers user sends + new assistant msgs)
249
  const prevLenRef = useRef(initialMessages.length);
250
  useEffect(() => {
251
  if (!sessionId || chat.messages.length === 0) return;
 
285
  [sessionId, setProcessing],
286
  );
287
 
288
+ // ── Flush current messages to localStorage (call before switching sessions) ──
289
+ const flushMessages = useCallback(() => {
290
+ const { sid, msgs } = flushRef.current;
291
+ if (sid && msgs.length > 0) saveMessages(sid, msgs);
292
+ }, []);
293
+
294
  return {
295
  messages: chat.messages,
296
  sendMessage: chat.sendMessage,
 
298
  status: chat.status,
299
  undoLastTurn,
300
  approveTools,
301
+ flushMessages,
302
  transport: transportRef.current,
303
  };
304
  }
frontend/src/lib/convert-llm-messages.ts ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Convert backend LLM messages (litellm format) to Vercel AI SDK UIMessage format.
3
+ */
4
+ import type { UIMessage } from 'ai';
5
+
6
+ interface LLMToolCall {
7
+ id: string;
8
+ function: { name: string; arguments: string };
9
+ }
10
+
11
+ interface LLMMessage {
12
+ role: 'user' | 'assistant' | 'tool' | 'system';
13
+ content: string | null;
14
+ tool_calls?: LLMToolCall[] | null;
15
+ tool_call_id?: string | null;
16
+ name?: string | null;
17
+ }
18
+
19
+ let idCounter = 0;
20
+ function nextId(): string {
21
+ return `msg-${Date.now()}-${++idCounter}`;
22
+ }
23
+
24
+ export function llmMessagesToUIMessages(messages: LLMMessage[]): UIMessage[] {
25
+ // Build a map of tool_call_id -> tool result for pairing
26
+ const toolResults = new Map<string, { output: string; isError: boolean }>();
27
+ for (const msg of messages) {
28
+ if (msg.role === 'tool' && msg.tool_call_id) {
29
+ toolResults.set(msg.tool_call_id, {
30
+ output: msg.content || '',
31
+ isError: false,
32
+ });
33
+ }
34
+ }
35
+
36
+ const uiMessages: UIMessage[] = [];
37
+
38
+ for (const msg of messages) {
39
+ if (msg.role === 'system') continue;
40
+ if (msg.role === 'tool') continue; // handled via tool_calls pairing
41
+
42
+ if (msg.role === 'user') {
43
+ uiMessages.push({
44
+ id: nextId(),
45
+ role: 'user',
46
+ parts: [{ type: 'text', text: msg.content || '' }],
47
+ });
48
+ continue;
49
+ }
50
+
51
+ if (msg.role === 'assistant') {
52
+ const parts: UIMessage['parts'] = [];
53
+
54
+ if (msg.content) {
55
+ parts.push({ type: 'text', text: msg.content });
56
+ }
57
+
58
+ if (msg.tool_calls) {
59
+ for (const tc of msg.tool_calls) {
60
+ let input: Record<string, unknown> = {};
61
+ try {
62
+ input = JSON.parse(tc.function.arguments);
63
+ } catch { /* malformed */ }
64
+
65
+ const result = toolResults.get(tc.id);
66
+ if (result) {
67
+ parts.push({
68
+ type: 'dynamic-tool',
69
+ toolCallId: tc.id,
70
+ toolName: tc.function.name,
71
+ state: 'output-available',
72
+ input,
73
+ output: result.output,
74
+ });
75
+ } else {
76
+ parts.push({
77
+ type: 'dynamic-tool',
78
+ toolCallId: tc.id,
79
+ toolName: tc.function.name,
80
+ state: 'input-available',
81
+ input,
82
+ });
83
+ }
84
+ }
85
+ }
86
+
87
+ uiMessages.push({
88
+ id: nextId(),
89
+ role: 'assistant',
90
+ parts,
91
+ });
92
+ }
93
+ }
94
+
95
+ return uiMessages;
96
+ }
frontend/src/lib/ws-chat-transport.ts CHANGED
@@ -66,6 +66,9 @@ export class WebSocketChatTransport implements ChatTransport<UIMessage> {
66
  private currentSessionId: string | null = null;
67
  private sideChannel: SideChannelCallbacks;
68
 
 
 
 
69
  private streamController: ReadableStreamDefaultController<UIMessageChunk> | null = null;
70
  private streamGeneration = 0;
71
  private abortedGeneration = 0;
@@ -130,9 +133,56 @@ export class WebSocketChatTransport implements ChatTransport<UIMessage> {
130
  clearTimeout(this.connectTimeout);
131
  this.connectTimeout = null;
132
  }
133
- this.disconnectWebSocket();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  this.currentSessionId = sessionId;
135
  if (sessionId) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  this.retries = 0;
137
  this.reconnectDelay = WS_RECONNECT_DELAY;
138
  this.connectTimeout = setTimeout(() => {
@@ -171,6 +221,9 @@ export class WebSocketChatTransport implements ChatTransport<UIMessage> {
171
  document.removeEventListener('visibilitychange', this.boundVisibilityHandler);
172
  this.boundVisibilityHandler = null;
173
  }
 
 
 
174
  this.disconnectWebSocket();
175
  this.closeActiveStream();
176
  }
 
66
  private currentSessionId: string | null = null;
67
  private sideChannel: SideChannelCallbacks;
68
 
69
+ /** Background WebSockets kept alive so the backend agent keeps running. */
70
+ private backgroundSockets: Map<string, WebSocket> = new Map();
71
+
72
  private streamController: ReadableStreamDefaultController<UIMessageChunk> | null = null;
73
  private streamGeneration = 0;
74
  private abortedGeneration = 0;
 
133
  clearTimeout(this.connectTimeout);
134
  this.connectTimeout = null;
135
  }
136
+
137
+ // Move current WS to background instead of closing it
138
+ if (this.ws && this.currentSessionId && this.currentSessionId !== sessionId) {
139
+ const oldId = this.currentSessionId;
140
+ const oldWs = this.ws;
141
+ this.backgroundSockets.set(oldId, oldWs);
142
+ // Replace handler: background sockets only need ping/pong
143
+ oldWs.onmessage = (evt) => {
144
+ try {
145
+ const raw = JSON.parse(evt.data);
146
+ if (raw.type === 'pong') return;
147
+ // Silently discard — backend keeps running, we'll load results from localStorage
148
+ } catch { /* ignore */ }
149
+ };
150
+ oldWs.onclose = () => {
151
+ this.backgroundSockets.delete(oldId);
152
+ };
153
+ this.ws = null;
154
+ this.stopPing();
155
+ } else {
156
+ this.disconnectWebSocket();
157
+ }
158
+
159
  this.currentSessionId = sessionId;
160
  if (sessionId) {
161
+ // Promote background socket if one exists for this session
162
+ const bg = this.backgroundSockets.get(sessionId);
163
+ if (bg && (bg.readyState === WebSocket.OPEN || bg.readyState === WebSocket.CONNECTING)) {
164
+ this.backgroundSockets.delete(sessionId);
165
+ this.ws = bg;
166
+ // Restore full event handling
167
+ bg.onmessage = (evt) => {
168
+ try {
169
+ const raw = JSON.parse(evt.data);
170
+ if (raw.type === 'pong') return;
171
+ this.handleEvent(raw as AgentEvent);
172
+ } catch (e) {
173
+ logger.error('WS parse error:', e);
174
+ }
175
+ };
176
+ bg.onclose = (evt) => {
177
+ logger.log('WS closed', evt.code, evt.reason);
178
+ this.sideChannel.onConnectionChange(false);
179
+ this.stopPing();
180
+ };
181
+ this.sideChannel.onConnectionChange(true);
182
+ this.startPing();
183
+ return;
184
+ }
185
+
186
  this.retries = 0;
187
  this.reconnectDelay = WS_RECONNECT_DELAY;
188
  this.connectTimeout = setTimeout(() => {
 
221
  document.removeEventListener('visibilitychange', this.boundVisibilityHandler);
222
  this.boundVisibilityHandler = null;
223
  }
224
+ // Close all background sockets
225
+ for (const ws of this.backgroundSockets.values()) ws.close();
226
+ this.backgroundSockets.clear();
227
  this.disconnectWebSocket();
228
  this.closeActiveStream();
229
  }