akseljoonas HF Staff akseljoonas commited on
Commit
f729708
·
1 Parent(s): 392de34

fix: normalize dict tool_calls from litellm streaming to proper ToolCall objects

Browse files

litellm's streaming handler uses .dict() internally (Pydantic v2 deprecated
path), and because Message has validate_assignment=False by default, direct
attribute assignment can leave tool_calls as raw dicts instead of
ChatCompletionMessageToolCall objects. Add _normalize_tool_calls() to
convert them back before sanitization and dangling-call patching.

Co-Authored-By: akseljoonas <akseljoonas@users.noreply.github.com>

Files changed (1) hide show
  1. agent/context_manager/manager.py +27 -35
agent/context_manager/manager.py CHANGED
@@ -142,35 +142,25 @@ class ContextManager:
142
  return self.items
143
 
144
  @staticmethod
145
- def _tc_id(tc) -> str:
146
- if isinstance(tc, dict):
147
- return tc.get("id", "")
148
- return tc.id
149
 
150
- @staticmethod
151
- def _tc_func_name(tc) -> str:
152
- if isinstance(tc, dict):
153
- fn = tc.get("function", {})
154
- return fn.get("name", "") if isinstance(fn, dict) else getattr(fn, "name", "")
155
- return tc.function.name
156
-
157
- @staticmethod
158
- def _tc_func_args(tc) -> str:
159
- if isinstance(tc, dict):
160
- fn = tc.get("function", {})
161
- return fn.get("arguments", "{}") if isinstance(fn, dict) else getattr(fn, "arguments", "{}")
162
- return tc.function.arguments
163
 
164
- @staticmethod
165
- def _tc_set_func_args(tc, value: str) -> None:
166
- if isinstance(tc, dict):
167
- fn = tc.get("function")
168
- if isinstance(fn, dict):
169
- fn["arguments"] = value
170
- else:
171
- fn.arguments = value
172
- else:
173
- tc.function.arguments = value
174
 
175
  def _sanitize_tool_calls(self) -> None:
176
  """Fix malformed tool_call arguments across all assistant messages."""
@@ -182,15 +172,17 @@ class ContextManager:
182
  tool_calls = getattr(msg, "tool_calls", None)
183
  if not tool_calls:
184
  continue
185
- for tc in tool_calls:
 
 
186
  try:
187
- json.loads(self._tc_func_args(tc))
188
  except (json.JSONDecodeError, TypeError, ValueError):
189
  logger.warning(
190
  "Sanitizing malformed arguments for tool_call %s (%s)",
191
- self._tc_id(tc), self._tc_func_name(tc),
192
  )
193
- self._tc_set_func_args(tc, "{}")
194
  def _patch_dangling_tool_calls(self) -> None:
195
  """Add stub tool results for any tool_calls that lack a matching result.
196
 
@@ -215,20 +207,20 @@ class ContextManager:
215
  if not assistant_msg:
216
  return
217
 
 
218
  answered_ids = {
219
  getattr(m, "tool_call_id", None)
220
  for m in self.items
221
  if getattr(m, "role", None) == "tool"
222
  }
223
  for tc in assistant_msg.tool_calls:
224
- tc_id = self._tc_id(tc)
225
- if tc_id not in answered_ids:
226
  self.items.append(
227
  Message(
228
  role="tool",
229
  content="Tool was not executed (interrupted or error).",
230
- tool_call_id=tc_id,
231
- name=self._tc_func_name(tc),
232
  )
233
  )
234
 
 
142
  return self.items
143
 
144
  @staticmethod
145
+ def _normalize_tool_calls(msg: Message) -> None:
146
+ """Ensure msg.tool_calls contains proper ToolCall objects, not dicts.
 
 
147
 
148
+ litellm's Message has validate_assignment=False (Pydantic v2 default),
149
+ so direct attribute assignment (e.g. inside litellm's streaming handler)
150
+ can leave raw dicts. Re-assigning via the constructor fixes this.
151
+ """
152
+ from litellm import ChatCompletionMessageToolCall as ToolCall
 
 
 
 
 
 
 
 
153
 
154
+ tool_calls = getattr(msg, "tool_calls", None)
155
+ if not tool_calls:
156
+ return
157
+ needs_fix = any(isinstance(tc, dict) for tc in tool_calls)
158
+ if not needs_fix:
159
+ return
160
+ msg.tool_calls = [
161
+ tc if not isinstance(tc, dict) else ToolCall(**tc)
162
+ for tc in tool_calls
163
+ ]
164
 
165
  def _sanitize_tool_calls(self) -> None:
166
  """Fix malformed tool_call arguments across all assistant messages."""
 
172
  tool_calls = getattr(msg, "tool_calls", None)
173
  if not tool_calls:
174
  continue
175
+ # Ensure proper ToolCall objects (litellm streaming can leave dicts)
176
+ self._normalize_tool_calls(msg)
177
+ for tc in msg.tool_calls:
178
  try:
179
+ json.loads(tc.function.arguments)
180
  except (json.JSONDecodeError, TypeError, ValueError):
181
  logger.warning(
182
  "Sanitizing malformed arguments for tool_call %s (%s)",
183
+ tc.id, tc.function.name,
184
  )
185
+ tc.function.arguments = "{}"
186
  def _patch_dangling_tool_calls(self) -> None:
187
  """Add stub tool results for any tool_calls that lack a matching result.
188
 
 
207
  if not assistant_msg:
208
  return
209
 
210
+ self._normalize_tool_calls(assistant_msg)
211
  answered_ids = {
212
  getattr(m, "tool_call_id", None)
213
  for m in self.items
214
  if getattr(m, "role", None) == "tool"
215
  }
216
  for tc in assistant_msg.tool_calls:
217
+ if tc.id not in answered_ids:
 
218
  self.items.append(
219
  Message(
220
  role="tool",
221
  content="Tool was not executed (interrupted or error).",
222
+ tool_call_id=tc.id,
223
+ name=tc.function.name,
224
  )
225
  )
226