akseljoonas HF Staff committed on
Commit
288473a
·
1 Parent(s): 5ddd727

fix: correct context_length init and emit tool_call events for malformed calls

Browse files

- Initialize context_length to max_context (rather than a token estimate of the system prompt)
and reserve a 10k-token buffer (max_context is lowered by 10,000) to prevent overflows
- Emit synthetic tool_call events before tool_output errors for malformed calls
so the frontend renders matching dynamic-tool parts

agent/context_manager/manager.py CHANGED
@@ -85,9 +85,9 @@ class ContextManager:
85
  prompt_file_suffix="system_prompt_v3.yaml",
86
  hf_token=hf_token,
87
  )
88
- self.max_context = max_context
89
  self.compact_size = int(max_context * compact_size)
90
- self.context_length = len(self.system_prompt) // 4
91
  self.untouched_messages = untouched_messages
92
  self.items: list[Message] = [Message(role="system", content=self.system_prompt)]
93
 
@@ -160,8 +160,7 @@ class ContextManager:
160
  if not needs_fix:
161
  return
162
  msg.tool_calls = [
163
- tc if not isinstance(tc, dict) else ToolCall(**tc)
164
- for tc in tool_calls
165
  ]
166
 
167
  def recover_malformed_tool_calls(self) -> set[str]:
@@ -214,7 +213,9 @@ class ContextManager:
214
  except (json.JSONDecodeError, TypeError, ValueError) as e:
215
  logger.warning(
216
  "Malformed arguments for tool_call %s (%s): %s",
217
- tc.id, tc.function.name, e,
 
 
218
  )
219
  tc.function.arguments = "{}"
220
  malformed_ids.add(tc.id)
@@ -268,7 +269,9 @@ class ContextManager:
268
  assistant_msg = None
269
  for i in range(len(self.items) - 1, -1, -1):
270
  msg = self.items[i]
271
- if getattr(msg, "role", None) == "assistant" and getattr(msg, "tool_calls", None):
 
 
272
  assistant_msg = msg
273
  break
274
  # Stop scanning once we hit a user message — anything before
 
85
  prompt_file_suffix="system_prompt_v3.yaml",
86
  hf_token=hf_token,
87
  )
88
+ self.max_context = max_context - 10000
89
  self.compact_size = int(max_context * compact_size)
90
+ self.context_length = max_context
91
  self.untouched_messages = untouched_messages
92
  self.items: list[Message] = [Message(role="system", content=self.system_prompt)]
93
 
 
160
  if not needs_fix:
161
  return
162
  msg.tool_calls = [
163
+ tc if not isinstance(tc, dict) else ToolCall(**tc) for tc in tool_calls
 
164
  ]
165
 
166
  def recover_malformed_tool_calls(self) -> set[str]:
 
213
  except (json.JSONDecodeError, TypeError, ValueError) as e:
214
  logger.warning(
215
  "Malformed arguments for tool_call %s (%s): %s",
216
+ tc.id,
217
+ tc.function.name,
218
+ e,
219
  )
220
  tc.function.arguments = "{}"
221
  malformed_ids.add(tc.id)
 
269
  assistant_msg = None
270
  for i in range(len(self.items) - 1, -1, -1):
271
  msg = self.items[i]
272
+ if getattr(msg, "role", None) == "assistant" and getattr(
273
+ msg, "tool_calls", None
274
+ ):
275
  assistant_msg = msg
276
  break
277
  # Stop scanning once we hit a user message — anything before
agent/core/agent_loop.py CHANGED
@@ -38,7 +38,9 @@ def _resolve_hf_router_params(model_name: str) -> dict:
38
  if not model_name.startswith("huggingface/"):
39
  return {"model": model_name}
40
 
41
- parts = model_name.split("/", 2) # ['huggingface', 'novita', 'moonshotai/kimi-k2.5']
 
 
42
  if len(parts) < 3:
43
  return {"model": model_name}
44
 
@@ -163,8 +165,6 @@ async def _compact_and_notify(session: Session) -> None:
163
  )
164
 
165
 
166
-
167
-
168
  class Handlers:
169
  """Handler functions for each operation type"""
170
 
@@ -179,7 +179,9 @@ class Handlers:
179
  tool_calls = session.pending_approval.get("tool_calls", [])
180
  for tc in tool_calls:
181
  tool_name = tc.function.name
182
- abandon_msg = "Task abandoned — user continued the conversation without approving."
 
 
183
 
184
  # Keep LLM context valid: every tool_call needs a tool result
185
  tool_msg = Message(
@@ -372,21 +374,40 @@ class Handlers:
372
  # Recover any malformed tool calls (sanitize JSON + inject
373
  # error results). Returns IDs to skip during execution.
374
  malformed_ids = session.context_manager.recover_malformed_tool_calls()
375
- for mid in malformed_ids:
376
- await session.send_event(
377
- Event(
378
- event_type="tool_output",
379
- data={
380
- "tool": next(
381
- (tc.function.name for tc in tool_calls if tc.id == mid),
382
- "unknown",
383
- ),
384
- "tool_call_id": mid,
385
- "output": "Malformed tool call — see error in context.",
386
- "success": False,
387
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  )
389
- )
390
 
391
  # Separate tools into those requiring approval and those that don't
392
  approval_required_tools = []
@@ -499,10 +520,15 @@ class Handlers:
499
 
500
  # Resolve sandbox file paths for hf_jobs scripts so the
501
  # frontend can display & edit the actual file content.
502
- if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str):
 
 
503
  from agent.tools.sandbox_tool import resolve_sandbox_script
 
504
  sandbox = getattr(session, "sandbox", None)
505
- content, _ = await resolve_sandbox_script(sandbox, tool_args["script"])
 
 
506
  if content:
507
  tool_args = {**tool_args, "script": content}
508
 
@@ -604,7 +630,9 @@ class Handlers:
604
  approval_map = {a["tool_call_id"]: a for a in approvals}
605
  for a in approvals:
606
  if a.get("edited_script"):
607
- logger.info(f"Received edited script for tool_call {a['tool_call_id']} ({len(a['edited_script'])} chars)")
 
 
608
 
609
  # Separate approved and rejected tool calls
610
  approved_tasks = []
@@ -750,7 +778,9 @@ class Handlers:
750
  # Ensure feedback is a string and sanitize any problematic characters
751
  feedback_str = str(user_feedback).strip()
752
  # Remove any control characters that might break JSON parsing
753
- feedback_str = "".join(char for char in feedback_str if ord(char) >= 32 or char in "\n\t")
 
 
754
  rejection_msg += f". User feedback: {feedback_str}"
755
 
756
  # Ensure rejection_msg is a clean string
@@ -844,7 +874,11 @@ async def submission_loop(
844
  """
845
 
846
  # Create session with tool router
847
- session = Session(event_queue, config=config, tool_router=tool_router)
 
 
 
 
848
  logger.info("Agent loop started")
849
 
850
  # Retry any failed uploads from previous sessions (fire-and-forget)
 
38
  if not model_name.startswith("huggingface/"):
39
  return {"model": model_name}
40
 
41
+ parts = model_name.split(
42
+ "/", 2
43
+ ) # ['huggingface', 'novita', 'moonshotai/kimi-k2.5']
44
  if len(parts) < 3:
45
  return {"model": model_name}
46
 
 
165
  )
166
 
167
 
 
 
168
  class Handlers:
169
  """Handler functions for each operation type"""
170
 
 
179
  tool_calls = session.pending_approval.get("tool_calls", [])
180
  for tc in tool_calls:
181
  tool_name = tc.function.name
182
+ abandon_msg = (
183
+ "Task abandoned — user continued the conversation without approving."
184
+ )
185
 
186
  # Keep LLM context valid: every tool_call needs a tool result
187
  tool_msg = Message(
 
374
  # Recover any malformed tool calls (sanitize JSON + inject
375
  # error results). Returns IDs to skip during execution.
376
  malformed_ids = session.context_manager.recover_malformed_tool_calls()
377
+ if malformed_ids:
378
+ # For each malformed tool_call, emit a synthetic tool_call +
379
+ # tool_output-error pair so the frontend has a matching
380
+ # dynamic-tool part instead of an orphan error.
381
+ for tc in tool_calls:
382
+ if tc.id not in malformed_ids:
383
+ continue
384
+ tool_name = tc.function.name
385
+ try:
386
+ tool_args = json.loads(tc.function.arguments)
387
+ except (json.JSONDecodeError, TypeError, ValueError):
388
+ tool_args = {}
389
+
390
+ await session.send_event(
391
+ Event(
392
+ event_type="tool_call",
393
+ data={
394
+ "tool": tool_name,
395
+ "arguments": tool_args,
396
+ "tool_call_id": tc.id,
397
+ },
398
+ )
399
+ )
400
+ await session.send_event(
401
+ Event(
402
+ event_type="tool_output",
403
+ data={
404
+ "tool": tool_name,
405
+ "tool_call_id": tc.id,
406
+ "output": "Malformed tool call — see error in context.",
407
+ "success": False,
408
+ },
409
+ )
410
  )
 
411
 
412
  # Separate tools into those requiring approval and those that don't
413
  approval_required_tools = []
 
520
 
521
  # Resolve sandbox file paths for hf_jobs scripts so the
522
  # frontend can display & edit the actual file content.
523
+ if tool_name == "hf_jobs" and isinstance(
524
+ tool_args.get("script"), str
525
+ ):
526
  from agent.tools.sandbox_tool import resolve_sandbox_script
527
+
528
  sandbox = getattr(session, "sandbox", None)
529
+ content, _ = await resolve_sandbox_script(
530
+ sandbox, tool_args["script"]
531
+ )
532
  if content:
533
  tool_args = {**tool_args, "script": content}
534
 
 
630
  approval_map = {a["tool_call_id"]: a for a in approvals}
631
  for a in approvals:
632
  if a.get("edited_script"):
633
+ logger.info(
634
+ f"Received edited script for tool_call {a['tool_call_id']} ({len(a['edited_script'])} chars)"
635
+ )
636
 
637
  # Separate approved and rejected tool calls
638
  approved_tasks = []
 
778
  # Ensure feedback is a string and sanitize any problematic characters
779
  feedback_str = str(user_feedback).strip()
780
  # Remove any control characters that might break JSON parsing
781
+ feedback_str = "".join(
782
+ char for char in feedback_str if ord(char) >= 32 or char in "\n\t"
783
+ )
784
  rejection_msg += f". User feedback: {feedback_str}"
785
 
786
  # Ensure rejection_msg is a clean string
 
874
  """
875
 
876
  # Create session with tool router
877
+ session = Session(
878
+ event_queue, config=config, tool_router=tool_router, hf_token=hf_token
879
+ )
880
+ if session_holder is not None:
881
+ session_holder[0] = session
882
  logger.info("Agent loop started")
883
 
884
  # Retry any failed uploads from previous sessions (fire-and-forget)