akseljoonas HF Staff Claude Opus 4.6 (1M context) committed on
Commit
263e229
·
1 Parent(s): 2383b4b

feat: LLM-based pruning for complex tool outputs (hf_jobs)

Browse files

Tools in _LLM_PRUNE_TOOLS get a cheap LLM call (≤600 tokens) that
extracts the answer from the raw output, preserving job IDs, metrics,
and error messages exactly. Falls back to deterministic truncation if
the LLM call fails.

Three tool lists now control pruning behavior:
- _PRUNE_SKIP_TOOLS: never pruned (research, plan_tool)
- _LLM_PRUNE_TOOLS: LLM-summarised (hf_jobs)
- Everything else: deterministic 80-char truncation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

agent/context_manager/manager.py CHANGED
@@ -224,17 +224,38 @@ class ContextManager:
224
  # Tools whose outputs should never be pruned (too valuable to summarise)
225
  _PRUNE_SKIP_TOOLS: set[str] = {"research", "plan_tool"}
226
 
227
- def prune_old_tool_outputs(self) -> None:
228
- """Stage 1 compaction: deterministically truncate old tool outputs.
229
-
230
- For any tool message older than the last 6 messages, replace content
231
- exceeding 500 chars with a short one-line summary preserving
232
- tool_call_id and name.
 
 
 
 
 
 
 
233
  """
234
  if len(self.items) <= 6:
235
  return
236
 
237
  cutoff = len(self.items) - 6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  for i in range(cutoff - 1, -1, -1):
239
  msg = self.items[i]
240
  if getattr(msg, "role", None) != "tool":
@@ -246,18 +267,55 @@ class ContextManager:
246
  tool_name = getattr(msg, "name", None) or "tool"
247
  if tool_name in self._PRUNE_SKIP_TOOLS:
248
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  preview = content[:80]
250
  total = len(content)
251
 
252
- if tool_name == "hf_jobs":
253
- summary = f"[hf_jobs: {preview}... ({total} chars)]"
254
- elif tool_name == "bash":
255
- # Try to extract exit_code from content
256
  exit_code_part = ""
257
  if "exit_code" in content[:200]:
258
  for line in content[:200].splitlines():
259
  if "exit_code" in line:
260
- exit_code_part = f"exit_code visible if present, "
261
  break
262
  summary = f"[bash: {exit_code_part}{preview}... ({total} chars)]"
263
  else:
@@ -269,7 +327,7 @@ class ContextManager:
269
  self, model_name: str, tool_specs: list[dict] | None = None
270
  ) -> None:
271
  """Remove old messages to keep history under target size"""
272
- self.prune_old_tool_outputs()
273
 
274
  if (self.context_length <= self.max_context) or not self.items:
275
  return
 
224
  # Tools whose outputs should never be pruned (too valuable to summarise)
225
  _PRUNE_SKIP_TOOLS: set[str] = {"research", "plan_tool"}
226
 
227
+ # Tools whose outputs are pruned via a cheap LLM call instead of
228
+ # deterministic truncation (the output structure is too complex for
229
+ # a fixed head-slice to capture the answer reliably).
230
+ _LLM_PRUNE_TOOLS: set[str] = {"hf_jobs"}
231
+
232
+ async def prune_old_tool_outputs(self, model_name: str | None = None) -> None:
233
+ """Stage 1 compaction: shrink old tool outputs.
234
+
235
+ For any tool message older than the last 6 messages whose content
236
+ exceeds 500 chars:
237
+ - Tools in _LLM_PRUNE_TOOLS get a cheap LLM summarisation (≤600 tokens).
238
+ - All other tools get a deterministic one-line summary.
239
+ tool_call_id and name are always preserved.
240
  """
241
  if len(self.items) <= 6:
242
  return
243
 
244
  cutoff = len(self.items) - 6
245
+
246
+ # Find the preceding assistant tool_call arguments so the LLM
247
+ # knows what question the tool output was answering.
248
+ def _find_tool_call_args(tool_call_id: str) -> str | None:
249
+ for msg in self.items:
250
+ if getattr(msg, "role", None) != "assistant":
251
+ continue
252
+ for tc in getattr(msg, "tool_calls", None) or []:
253
+ tc_id = tc.id if hasattr(tc, "id") else tc.get("id")
254
+ if tc_id == tool_call_id:
255
+ fn = tc.function if hasattr(tc, "function") else tc.get("function", {})
256
+ return fn.arguments if hasattr(fn, "arguments") else fn.get("arguments", "")
257
+ return None
258
+
259
  for i in range(cutoff - 1, -1, -1):
260
  msg = self.items[i]
261
  if getattr(msg, "role", None) != "tool":
 
267
  tool_name = getattr(msg, "name", None) or "tool"
268
  if tool_name in self._PRUNE_SKIP_TOOLS:
269
  continue
270
+
271
+ # --- LLM-based pruning for complex tool outputs ---
272
+ if tool_name in self._LLM_PRUNE_TOOLS and model_name:
273
+ call_args = _find_tool_call_args(getattr(msg, "tool_call_id", ""))
274
+ context_line = (
275
+ f"The tool was called with: {call_args}\n\n" if call_args else ""
276
+ )
277
+ try:
278
+ hf_key = os.environ.get("INFERENCE_TOKEN")
279
+ resp = await acompletion(
280
+ model=model_name,
281
+ messages=[
282
+ Message(
283
+ role="user",
284
+ content=(
285
+ f"{context_line}"
286
+ f"Below is the raw output of the '{tool_name}' tool.\n"
287
+ "Give the answer to the original request unchanged — "
288
+ "preserve all job IDs, numbers, status values, error "
289
+ "messages, and metrics exactly. Omit filler/boilerplate. "
290
+ "Stay under 600 tokens.\n\n"
291
+ f"{content}"
292
+ ),
293
+ )
294
+ ],
295
+ max_completion_tokens=600,
296
+ api_key=hf_key
297
+ if hf_key and model_name.startswith("huggingface/")
298
+ else None,
299
+ )
300
+ msg.content = resp.choices[0].message.content
301
+ continue
302
+ except Exception:
303
+ logger.warning(
304
+ "LLM prune failed for %s, falling back to deterministic",
305
+ tool_name,
306
+ )
307
+ # fall through to deterministic pruning below
308
+
309
+ # --- Deterministic pruning ---
310
  preview = content[:80]
311
  total = len(content)
312
 
313
+ if tool_name == "bash":
 
 
 
314
  exit_code_part = ""
315
  if "exit_code" in content[:200]:
316
  for line in content[:200].splitlines():
317
  if "exit_code" in line:
318
+ exit_code_part = "exit_code visible if present, "
319
  break
320
  summary = f"[bash: {exit_code_part}{preview}... ({total} chars)]"
321
  else:
 
327
  self, model_name: str, tool_specs: list[dict] | None = None
328
  ) -> None:
329
  """Remove old messages to keep history under target size"""
330
+ await self.prune_old_tool_outputs(model_name=model_name)
331
 
332
  if (self.context_length <= self.max_context) or not self.items:
333
  return
agent/core/agent_loop.py CHANGED
@@ -199,7 +199,9 @@ def _is_transient_error(error: Exception) -> bool:
199
 
200
  async def _compact_and_notify(session: Session) -> None:
201
  """Run compaction and send event if context was reduced."""
202
- session.context_manager.prune_old_tool_outputs()
 
 
203
  old_length = session.context_manager.context_length
204
  tool_specs = session.tool_router.get_tool_specs_for_llm()
205
  await session.context_manager.compact(
 
199
 
200
  async def _compact_and_notify(session: Session) -> None:
201
  """Run compaction and send event if context was reduced."""
202
+ await session.context_manager.prune_old_tool_outputs(
203
+ model_name=session.config.model_name,
204
+ )
205
  old_length = session.context_manager.context_length
206
  tool_specs = session.tool_router.get_tool_specs_for_llm()
207
  await session.context_manager.compact(