akseljoonas HF Staff commited on
Commit
6d4d388
·
1 Parent(s): 8d40f7c

Add message edit and regenerate from any point in conversation

Browse files

Backend: add truncate_to_user_message() on ContextManager, POST
/api/truncate/{session_id} endpoint, and session_manager.truncate()
helper to slice conversation history to before a given user message.

Frontend: add editAndRegenerate callback in useAgentChat that
truncates backend + frontend messages then resubmits via existing
chat flow. UserMessage gains an inline edit mode with pencil icon
on hover, TextField for editing, Enter/Escape to confirm/cancel.

agent/context_manager/manager.py CHANGED
@@ -243,6 +243,25 @@ class ContextManager:
243
 
244
  return False
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  async def compact(
247
  self, model_name: str, tool_specs: list[dict] | None = None
248
  ) -> None:
 
243
 
244
  return False
245
 
246
+ def truncate_to_user_message(self, user_message_index: int) -> bool:
247
+ """Truncate history to just before the Nth user message (0-indexed).
248
+
249
+ Removes that user message and everything after it.
250
+ System message (index 0) is never removed.
251
+
252
+ Returns True if the target user message was found and removed.
253
+ """
254
+ count = 0
255
+ for i, msg in enumerate(self.items):
256
+ if i == 0:
257
+ continue # skip system message
258
+ if getattr(msg, "role", None) == "user":
259
+ if count == user_message_index:
260
+ self.items = self.items[:i]
261
+ return True
262
+ count += 1
263
+ return False
264
+
265
  async def compact(
266
  self, model_name: str, tool_specs: list[dict] | None = None
267
  ) -> None:
backend/models.py CHANGED
@@ -54,6 +54,12 @@ class SubmitRequest(BaseModel):
54
  text: str
55
 
56
 
 
 
 
 
 
 
57
  class SessionResponse(BaseModel):
58
  """Response when creating a new session."""
59
 
 
54
  text: str
55
 
56
 
57
+ class TruncateRequest(BaseModel):
58
+ """Request to truncate conversation history to before a specific user message."""
59
+
60
+ user_message_index: int
61
+
62
+
63
  class SessionResponse(BaseModel):
64
  """Response when creating a new session."""
65
 
backend/routes/agent.py CHANGED
@@ -26,6 +26,7 @@ from models import (
26
  SessionInfo,
27
  SessionResponse,
28
  SubmitRequest,
 
29
  )
30
  from session_manager import MAX_SESSIONS, SessionCapacityError, session_manager
31
 
@@ -438,6 +439,18 @@ async def undo_session(session_id: str, user: dict = Depends(get_current_user))
438
  return {"status": "undo_requested", "session_id": session_id}
439
 
440
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  @router.post("/compact/{session_id}")
442
  async def compact_session(
443
  session_id: str, user: dict = Depends(get_current_user)
 
26
  SessionInfo,
27
  SessionResponse,
28
  SubmitRequest,
29
+ TruncateRequest,
30
  )
31
  from session_manager import MAX_SESSIONS, SessionCapacityError, session_manager
32
 
 
439
  return {"status": "undo_requested", "session_id": session_id}
440
 
441
 
442
+ @router.post("/truncate/{session_id}")
443
+ async def truncate_session(
444
+ session_id: str, body: TruncateRequest, user: dict = Depends(get_current_user)
445
+ ) -> dict:
446
+ """Truncate conversation to before a specific user message."""
447
+ _check_session_access(session_id, user)
448
+ success = await session_manager.truncate(session_id, body.user_message_index)
449
+ if not success:
450
+ raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range")
451
+ return {"status": "truncated", "session_id": session_id}
452
+
453
+
454
  @router.post("/compact/{session_id}")
455
  async def compact_session(
456
  session_id: str, user: dict = Depends(get_current_user)
backend/session_manager.py CHANGED
@@ -319,6 +319,14 @@ class SessionManager:
319
  operation = Operation(op_type=OpType.UNDO)
320
  return await self.submit(session_id, operation)
321
 
 
 
 
 
 
 
 
 
322
  async def compact(self, session_id: str) -> bool:
323
  """Compact context in a session."""
324
  operation = Operation(op_type=OpType.COMPACT)
 
319
  operation = Operation(op_type=OpType.UNDO)
320
  return await self.submit(session_id, operation)
321
 
322
+ async def truncate(self, session_id: str, user_message_index: int) -> bool:
323
+ """Truncate conversation to before a specific user message (direct, no queue)."""
324
+ async with self._lock:
325
+ agent_session = self.sessions.get(session_id)
326
+ if not agent_session or not agent_session.is_active:
327
+ return False
328
+ return agent_session.session.context_manager.truncate_to_user_message(user_message_index)
329
+
330
  async def compact(self, session_id: str) -> bool:
331
  """Compact context in a session."""
332
  operation = Operation(op_type=OpType.COMPACT)
frontend/src/components/Chat/MessageBubble.tsx CHANGED
@@ -6,6 +6,7 @@ interface MessageBubbleProps {
6
  message: UIMessage;
7
  isLastTurn?: boolean;
8
  onUndoTurn?: () => void;
 
9
  isProcessing?: boolean;
10
  isStreaming?: boolean;
11
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
@@ -15,6 +16,7 @@ export default function MessageBubble({
15
  message,
16
  isLastTurn = false,
17
  onUndoTurn,
 
18
  isProcessing = false,
19
  isStreaming = false,
20
  approveTools,
@@ -25,6 +27,7 @@ export default function MessageBubble({
25
  message={message}
26
  isLastTurn={isLastTurn}
27
  onUndoTurn={onUndoTurn}
 
28
  isProcessing={isProcessing}
29
  />
30
  );
 
6
  message: UIMessage;
7
  isLastTurn?: boolean;
8
  onUndoTurn?: () => void;
9
+ onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
10
  isProcessing?: boolean;
11
  isStreaming?: boolean;
12
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
 
16
  message,
17
  isLastTurn = false,
18
  onUndoTurn,
19
+ onEditAndRegenerate,
20
  isProcessing = false,
21
  isStreaming = false,
22
  approveTools,
 
27
  message={message}
28
  isLastTurn={isLastTurn}
29
  onUndoTurn={onUndoTurn}
30
+ onEditAndRegenerate={onEditAndRegenerate}
31
  isProcessing={isProcessing}
32
  />
33
  );
frontend/src/components/Chat/MessageList.tsx CHANGED
@@ -10,6 +10,7 @@ interface MessageListProps {
10
  isProcessing: boolean;
11
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
12
  onUndoLastTurn: () => void | Promise<void>;
 
13
  }
14
 
15
  function getGreeting(): string {
@@ -56,7 +57,7 @@ function WelcomeGreeting() {
56
  );
57
  }
58
 
59
- export default function MessageList({ messages, isProcessing, approveTools, onUndoLastTurn }: MessageListProps) {
60
  const scrollContainerRef = useRef<HTMLDivElement>(null);
61
  const stickToBottom = useRef(true);
62
 
@@ -135,6 +136,7 @@ export default function MessageList({ messages, isProcessing, approveTools, onUn
135
  message={msg}
136
  isLastTurn={msg.id === lastUserMsgId}
137
  onUndoTurn={onUndoLastTurn}
 
138
  isProcessing={isProcessing}
139
  isStreaming={isProcessing && msg.id === lastAssistantId}
140
  approveTools={approveTools}
 
10
  isProcessing: boolean;
11
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
12
  onUndoLastTurn: () => void | Promise<void>;
13
+ onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
14
  }
15
 
16
  function getGreeting(): string {
 
57
  );
58
  }
59
 
60
+ export default function MessageList({ messages, isProcessing, approveTools, onUndoLastTurn, onEditAndRegenerate }: MessageListProps) {
61
  const scrollContainerRef = useRef<HTMLDivElement>(null);
62
  const stickToBottom = useRef(true);
63
 
 
136
  message={msg}
137
  isLastTurn={msg.id === lastUserMsgId}
138
  onUndoTurn={onUndoLastTurn}
139
+ onEditAndRegenerate={onEditAndRegenerate}
140
  isProcessing={isProcessing}
141
  isStreaming={isProcessing && msg.id === lastAssistantId}
142
  approveTools={approveTools}
frontend/src/components/Chat/UserMessage.tsx CHANGED
@@ -1,5 +1,8 @@
1
- import { Box, Stack, Typography, IconButton, Tooltip } from '@mui/material';
 
2
  import CloseIcon from '@mui/icons-material/Close';
 
 
3
  import type { UIMessage } from 'ai';
4
  import type { MessageMeta } from '@/types/agent';
5
 
@@ -7,6 +10,7 @@ interface UserMessageProps {
7
  message: UIMessage;
8
  isLastTurn?: boolean;
9
  onUndoTurn?: () => void;
 
10
  isProcessing?: boolean;
11
  }
12
 
@@ -21,14 +25,57 @@ export default function UserMessage({
21
  message,
22
  isLastTurn = false,
23
  onUndoTurn,
 
24
  isProcessing = false,
25
  }: UserMessageProps) {
26
  const showUndo = isLastTurn && !isProcessing && !!onUndoTurn;
 
27
  const text = extractText(message);
28
  const meta = message.metadata as MessageMeta | undefined;
29
  const timeStr = meta?.createdAt
30
  ? new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })
31
  : null;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  return (
33
  <Stack
34
  direction="row"
@@ -36,35 +83,56 @@ export default function UserMessage({
36
  justifyContent="flex-end"
37
  alignItems="flex-start"
38
  sx={{
39
- '& .undo-btn': {
40
  opacity: 0,
41
  transition: 'opacity 0.15s ease',
42
  },
43
- '&:hover .undo-btn': {
44
  opacity: 1,
45
  },
46
  }}
47
  >
48
- {showUndo && (
49
- <Box className="undo-btn" sx={{ display: 'flex', alignItems: 'center', mt: 0.75 }}>
50
- <Tooltip title="Remove this turn" placement="left">
51
- <IconButton
52
- onClick={onUndoTurn}
53
- size="small"
54
- sx={{
55
- width: 24,
56
- height: 24,
57
- color: 'var(--muted-text)',
58
- '&:hover': {
59
- color: 'var(--accent-red)',
60
- bgcolor: 'rgba(244,67,54,0.08)',
61
- },
62
- }}
63
- >
64
- <CloseIcon sx={{ fontSize: 14 }} />
65
- </IconButton>
66
- </Tooltip>
67
- </Box>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  )}
69
 
70
  <Box
@@ -78,20 +146,66 @@ export default function UserMessage({
78
  border: '1px solid var(--border)',
79
  }}
80
  >
81
- <Typography
82
- variant="body1"
83
- sx={{
84
- fontSize: '0.925rem',
85
- lineHeight: 1.65,
86
- color: 'var(--text)',
87
- whiteSpace: 'pre-wrap',
88
- wordBreak: 'break-word',
89
- }}
90
- >
91
- {text}
92
- </Typography>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- {timeStr && (
95
  <Typography
96
  variant="caption"
97
  sx={{ color: 'var(--muted-text)', mt: 0.5, display: 'block', textAlign: 'right', fontSize: '0.7rem' }}
 
1
+ import { useState, useRef, useEffect } from 'react';
2
+ import { Box, Stack, Typography, IconButton, Tooltip, TextField } from '@mui/material';
3
  import CloseIcon from '@mui/icons-material/Close';
4
+ import EditIcon from '@mui/icons-material/Edit';
5
+ import CheckIcon from '@mui/icons-material/Check';
6
  import type { UIMessage } from 'ai';
7
  import type { MessageMeta } from '@/types/agent';
8
 
 
10
  message: UIMessage;
11
  isLastTurn?: boolean;
12
  onUndoTurn?: () => void;
13
+ onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
14
  isProcessing?: boolean;
15
  }
16
 
 
25
  message,
26
  isLastTurn = false,
27
  onUndoTurn,
28
+ onEditAndRegenerate,
29
  isProcessing = false,
30
  }: UserMessageProps) {
31
  const showUndo = isLastTurn && !isProcessing && !!onUndoTurn;
32
+ const showEdit = !isProcessing && !!onEditAndRegenerate;
33
  const text = extractText(message);
34
  const meta = message.metadata as MessageMeta | undefined;
35
  const timeStr = meta?.createdAt
36
  ? new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })
37
  : null;
38
+
39
+ const [isEditing, setIsEditing] = useState(false);
40
+ const [editText, setEditText] = useState(text);
41
+ const inputRef = useRef<HTMLTextAreaElement>(null);
42
+
43
+ useEffect(() => {
44
+ if (isEditing && inputRef.current) {
45
+ inputRef.current.focus();
46
+ inputRef.current.selectionStart = inputRef.current.value.length;
47
+ }
48
+ }, [isEditing]);
49
+
50
+ const handleStartEdit = () => {
51
+ setEditText(text);
52
+ setIsEditing(true);
53
+ };
54
+
55
+ const handleConfirmEdit = () => {
56
+ const trimmed = editText.trim();
57
+ if (!trimmed || trimmed === text) {
58
+ setIsEditing(false);
59
+ return;
60
+ }
61
+ setIsEditing(false);
62
+ onEditAndRegenerate?.(message.id, trimmed);
63
+ };
64
+
65
+ const handleCancelEdit = () => {
66
+ setIsEditing(false);
67
+ setEditText(text);
68
+ };
69
+
70
+ const handleKeyDown = (e: React.KeyboardEvent) => {
71
+ if (e.key === 'Enter' && !e.shiftKey) {
72
+ e.preventDefault();
73
+ handleConfirmEdit();
74
+ } else if (e.key === 'Escape') {
75
+ handleCancelEdit();
76
+ }
77
+ };
78
+
79
  return (
80
  <Stack
81
  direction="row"
 
83
  justifyContent="flex-end"
84
  alignItems="flex-start"
85
  sx={{
86
+ '& .action-btn': {
87
  opacity: 0,
88
  transition: 'opacity 0.15s ease',
89
  },
90
+ '&:hover .action-btn': {
91
  opacity: 1,
92
  },
93
  }}
94
  >
95
+ {!isEditing && (showUndo || showEdit) && (
96
+ <Stack className="action-btn" direction="row" spacing={0.25} sx={{ mt: 0.75 }}>
97
+ {showEdit && (
98
+ <Tooltip title="Edit & regenerate" placement="left">
99
+ <IconButton
100
+ onClick={handleStartEdit}
101
+ size="small"
102
+ sx={{
103
+ width: 24,
104
+ height: 24,
105
+ color: 'var(--muted-text)',
106
+ '&:hover': {
107
+ color: 'var(--accent-yellow)',
108
+ bgcolor: 'rgba(255,157,0,0.08)',
109
+ },
110
+ }}
111
+ >
112
+ <EditIcon sx={{ fontSize: 14 }} />
113
+ </IconButton>
114
+ </Tooltip>
115
+ )}
116
+ {showUndo && (
117
+ <Tooltip title="Remove this turn" placement="left">
118
+ <IconButton
119
+ onClick={onUndoTurn}
120
+ size="small"
121
+ sx={{
122
+ width: 24,
123
+ height: 24,
124
+ color: 'var(--muted-text)',
125
+ '&:hover': {
126
+ color: 'var(--accent-red)',
127
+ bgcolor: 'rgba(244,67,54,0.08)',
128
+ },
129
+ }}
130
+ >
131
+ <CloseIcon sx={{ fontSize: 14 }} />
132
+ </IconButton>
133
+ </Tooltip>
134
+ )}
135
+ </Stack>
136
  )}
137
 
138
  <Box
 
146
  border: '1px solid var(--border)',
147
  }}
148
  >
149
+ {isEditing ? (
150
+ <Stack spacing={1}>
151
+ <TextField
152
+ inputRef={inputRef}
153
+ multiline
154
+ fullWidth
155
+ value={editText}
156
+ onChange={(e) => setEditText(e.target.value)}
157
+ onKeyDown={handleKeyDown}
158
+ variant="outlined"
159
+ size="small"
160
+ sx={{
161
+ '& .MuiOutlinedInput-root': {
162
+ fontFamily: 'inherit',
163
+ fontSize: '0.925rem',
164
+ lineHeight: 1.65,
165
+ color: 'var(--text)',
166
+ '& fieldset': { borderColor: 'var(--accent-yellow)', borderWidth: 1.5 },
167
+ '&:hover fieldset': { borderColor: 'var(--accent-yellow)' },
168
+ '&.Mui-focused fieldset': { borderColor: 'var(--accent-yellow)' },
169
+ },
170
+ }}
171
+ />
172
+ <Stack direction="row" spacing={0.5} justifyContent="flex-end">
173
+ <Tooltip title="Cancel (Esc)">
174
+ <IconButton
175
+ onClick={handleCancelEdit}
176
+ size="small"
177
+ sx={{ color: 'var(--muted-text)', '&:hover': { color: 'var(--accent-red)' } }}
178
+ >
179
+ <CloseIcon sx={{ fontSize: 16 }} />
180
+ </IconButton>
181
+ </Tooltip>
182
+ <Tooltip title="Confirm (Enter)">
183
+ <IconButton
184
+ onClick={handleConfirmEdit}
185
+ size="small"
186
+ sx={{ color: 'var(--accent-green)', '&:hover': { bgcolor: 'rgba(47,204,113,0.1)' } }}
187
+ >
188
+ <CheckIcon sx={{ fontSize: 16 }} />
189
+ </IconButton>
190
+ </Tooltip>
191
+ </Stack>
192
+ </Stack>
193
+ ) : (
194
+ <Typography
195
+ variant="body1"
196
+ sx={{
197
+ fontSize: '0.925rem',
198
+ lineHeight: 1.65,
199
+ color: 'var(--text)',
200
+ whiteSpace: 'pre-wrap',
201
+ wordBreak: 'break-word',
202
+ }}
203
+ >
204
+ {text}
205
+ </Typography>
206
+ )}
207
 
208
+ {timeStr && !isEditing && (
209
  <Typography
210
  variant="caption"
211
  sx={{ color: 'var(--muted-text)', mt: 0.5, display: 'block', textAlign: 'right', fontSize: '0.7rem' }}
frontend/src/components/SessionChat.tsx CHANGED
@@ -24,7 +24,7 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
24
  const { isConnected, isProcessing, activityStatus, updateSession } = useAgentStore();
25
  const { updateSessionTitle } = useSessionStore();
26
 
27
- const { messages, sendMessage, stop, status, undoLastTurn, approveTools } = useAgentChat({
28
  sessionId,
29
  isActive,
30
  onReady: () => logger.log(`Session ${sessionId} ready`),
@@ -102,6 +102,7 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
102
  isProcessing={busy}
103
  approveTools={approveTools}
104
  onUndoLastTurn={undoLastTurn}
 
105
  />
106
  <ChatInput
107
  onSend={handleSendMessage}
 
24
  const { isConnected, isProcessing, activityStatus, updateSession } = useAgentStore();
25
  const { updateSessionTitle } = useSessionStore();
26
 
27
+ const { messages, sendMessage, stop, status, undoLastTurn, editAndRegenerate, approveTools } = useAgentChat({
28
  sessionId,
29
  isActive,
30
  onReady: () => logger.log(`Session ${sessionId} ready`),
 
102
  isProcessing={busy}
103
  approveTools={approveTools}
104
  onUndoLastTurn={undoLastTurn}
105
+ onEditAndRegenerate={editAndRegenerate}
106
  />
107
  <ChatInput
108
  onSend={handleSendMessage}
frontend/src/hooks/useAgentChat.ts CHANGED
@@ -640,12 +640,52 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
640
  apiFetch(`/api/interrupt/${sessionId}`, { method: 'POST' }).catch(() => {});
641
  }, [sessionId, updateSession]);
642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  return {
644
  messages: chat.messages,
645
  sendMessage: chat.sendMessage,
646
  stop,
647
  status: chat.status,
648
  undoLastTurn,
 
649
  approveTools,
650
  };
651
  }
 
640
  apiFetch(`/api/interrupt/${sessionId}`, { method: 'POST' }).catch(() => {});
641
  }, [sessionId, updateSession]);
642
 
643
+ // -- Edit message + regenerate from that point ----------------------------
644
+ const editAndRegenerate = useCallback(async (messageId: string, newText: string) => {
645
+ try {
646
+ const msgs = chatActionsRef.current.messages;
647
+ const setMsgs = chatActionsRef.current.setMessages;
648
+ if (!setMsgs) return;
649
+
650
+ // Find the target message and compute user message index (0-indexed, skipping system)
651
+ const msgIndex = msgs.findIndex(m => m.id === messageId);
652
+ if (msgIndex < 0) return;
653
+
654
+ let userMsgIndex = 0;
655
+ for (let i = 0; i < msgIndex; i++) {
656
+ if (msgs[i].role === 'user') userMsgIndex++;
657
+ }
658
+
659
+ // 1. Truncate backend history
660
+ const res = await apiFetch(`/api/truncate/${sessionId}`, {
661
+ method: 'POST',
662
+ body: JSON.stringify({ user_message_index: userMsgIndex }),
663
+ headers: { 'Content-Type': 'application/json' },
664
+ });
665
+ if (!res.ok) {
666
+ logger.error('Truncate API returned', res.status);
667
+ return;
668
+ }
669
+
670
+ // 2. Truncate frontend messages
671
+ const truncated = msgs.slice(0, msgIndex);
672
+ setMsgs(truncated);
673
+ saveMessages(sessionId, truncated);
674
+
675
+ // 3. Send the edited message (reuses existing transport + /api/chat)
676
+ chat.sendMessage({ text: newText, metadata: { createdAt: new Date().toISOString() } });
677
+ } catch (e) {
678
+ logger.error('Edit and regenerate failed:', e);
679
+ }
680
+ }, [sessionId, chat]);
681
+
682
  return {
683
  messages: chat.messages,
684
  sendMessage: chat.sendMessage,
685
  stop,
686
  status: chat.status,
687
  undoLastTurn,
688
+ editAndRegenerate,
689
  approveTools,
690
  };
691
  }