Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Commit ·
263e229
1
Parent(s): 2383b4b
feat: LLM-based pruning for complex tool outputs (hf_jobs)
Browse filesTools in _LLM_PRUNE_TOOLS get a cheap LLM call (≤600 tokens) that
extracts the answer from the raw output, preserving job IDs, metrics,
and error messages exactly. Falls back to deterministic truncation if
the LLM call fails.
Three tool lists now control pruning behavior:
- _PRUNE_SKIP_TOOLS: never pruned (research, plan_tool)
- _LLM_PRUNE_TOOLS: LLM-summarised (hf_jobs)
- Everything else: deterministic 80-char truncation
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- agent/context_manager/manager.py +70 -12
- agent/core/agent_loop.py +3 -1
agent/context_manager/manager.py
CHANGED
|
@@ -224,17 +224,38 @@ class ContextManager:
|
|
| 224 |
# Tools whose outputs should never be pruned (too valuable to summarise)
|
| 225 |
_PRUNE_SKIP_TOOLS: set[str] = {"research", "plan_tool"}
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
"""
|
| 234 |
if len(self.items) <= 6:
|
| 235 |
return
|
| 236 |
|
| 237 |
cutoff = len(self.items) - 6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
for i in range(cutoff - 1, -1, -1):
|
| 239 |
msg = self.items[i]
|
| 240 |
if getattr(msg, "role", None) != "tool":
|
|
@@ -246,18 +267,55 @@ class ContextManager:
|
|
| 246 |
tool_name = getattr(msg, "name", None) or "tool"
|
| 247 |
if tool_name in self._PRUNE_SKIP_TOOLS:
|
| 248 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
preview = content[:80]
|
| 250 |
total = len(content)
|
| 251 |
|
| 252 |
-
if tool_name == "
|
| 253 |
-
summary = f"[hf_jobs: {preview}... ({total} chars)]"
|
| 254 |
-
elif tool_name == "bash":
|
| 255 |
-
# Try to extract exit_code from content
|
| 256 |
exit_code_part = ""
|
| 257 |
if "exit_code" in content[:200]:
|
| 258 |
for line in content[:200].splitlines():
|
| 259 |
if "exit_code" in line:
|
| 260 |
-
exit_code_part =
|
| 261 |
break
|
| 262 |
summary = f"[bash: {exit_code_part}{preview}... ({total} chars)]"
|
| 263 |
else:
|
|
@@ -269,7 +327,7 @@ class ContextManager:
|
|
| 269 |
self, model_name: str, tool_specs: list[dict] | None = None
|
| 270 |
) -> None:
|
| 271 |
"""Remove old messages to keep history under target size"""
|
| 272 |
-
self.prune_old_tool_outputs()
|
| 273 |
|
| 274 |
if (self.context_length <= self.max_context) or not self.items:
|
| 275 |
return
|
|
|
|
| 224 |
# Tools whose outputs should never be pruned (too valuable to summarise)
|
| 225 |
_PRUNE_SKIP_TOOLS: set[str] = {"research", "plan_tool"}
|
| 226 |
|
| 227 |
+
# Tools whose outputs are pruned via a cheap LLM call instead of
|
| 228 |
+
# deterministic truncation (the output structure is too complex for
|
| 229 |
+
# a fixed head-slice to capture the answer reliably).
|
| 230 |
+
_LLM_PRUNE_TOOLS: set[str] = {"hf_jobs"}
|
| 231 |
+
|
| 232 |
+
async def prune_old_tool_outputs(self, model_name: str | None = None) -> None:
|
| 233 |
+
"""Stage 1 compaction: shrink old tool outputs.
|
| 234 |
+
|
| 235 |
+
For any tool message older than the last 6 messages whose content
|
| 236 |
+
exceeds 500 chars:
|
| 237 |
+
- Tools in _LLM_PRUNE_TOOLS get a cheap LLM summarisation (≤600 tokens).
|
| 238 |
+
- All other tools get a deterministic one-line summary.
|
| 239 |
+
tool_call_id and name are always preserved.
|
| 240 |
"""
|
| 241 |
if len(self.items) <= 6:
|
| 242 |
return
|
| 243 |
|
| 244 |
cutoff = len(self.items) - 6
|
| 245 |
+
|
| 246 |
+
# Find the preceding assistant tool_call arguments so the LLM
|
| 247 |
+
# knows what question the tool output was answering.
|
| 248 |
+
def _find_tool_call_args(tool_call_id: str) -> str | None:
|
| 249 |
+
for msg in self.items:
|
| 250 |
+
if getattr(msg, "role", None) != "assistant":
|
| 251 |
+
continue
|
| 252 |
+
for tc in getattr(msg, "tool_calls", None) or []:
|
| 253 |
+
tc_id = tc.id if hasattr(tc, "id") else tc.get("id")
|
| 254 |
+
if tc_id == tool_call_id:
|
| 255 |
+
fn = tc.function if hasattr(tc, "function") else tc.get("function", {})
|
| 256 |
+
return fn.arguments if hasattr(fn, "arguments") else fn.get("arguments", "")
|
| 257 |
+
return None
|
| 258 |
+
|
| 259 |
for i in range(cutoff - 1, -1, -1):
|
| 260 |
msg = self.items[i]
|
| 261 |
if getattr(msg, "role", None) != "tool":
|
|
|
|
| 267 |
tool_name = getattr(msg, "name", None) or "tool"
|
| 268 |
if tool_name in self._PRUNE_SKIP_TOOLS:
|
| 269 |
continue
|
| 270 |
+
|
| 271 |
+
# --- LLM-based pruning for complex tool outputs ---
|
| 272 |
+
if tool_name in self._LLM_PRUNE_TOOLS and model_name:
|
| 273 |
+
call_args = _find_tool_call_args(getattr(msg, "tool_call_id", ""))
|
| 274 |
+
context_line = (
|
| 275 |
+
f"The tool was called with: {call_args}\n\n" if call_args else ""
|
| 276 |
+
)
|
| 277 |
+
try:
|
| 278 |
+
hf_key = os.environ.get("INFERENCE_TOKEN")
|
| 279 |
+
resp = await acompletion(
|
| 280 |
+
model=model_name,
|
| 281 |
+
messages=[
|
| 282 |
+
Message(
|
| 283 |
+
role="user",
|
| 284 |
+
content=(
|
| 285 |
+
f"{context_line}"
|
| 286 |
+
f"Below is the raw output of the '{tool_name}' tool.\n"
|
| 287 |
+
"Give the answer to the original request unchanged — "
|
| 288 |
+
"preserve all job IDs, numbers, status values, error "
|
| 289 |
+
"messages, and metrics exactly. Omit filler/boilerplate. "
|
| 290 |
+
"Stay under 600 tokens.\n\n"
|
| 291 |
+
f"{content}"
|
| 292 |
+
),
|
| 293 |
+
)
|
| 294 |
+
],
|
| 295 |
+
max_completion_tokens=600,
|
| 296 |
+
api_key=hf_key
|
| 297 |
+
if hf_key and model_name.startswith("huggingface/")
|
| 298 |
+
else None,
|
| 299 |
+
)
|
| 300 |
+
msg.content = resp.choices[0].message.content
|
| 301 |
+
continue
|
| 302 |
+
except Exception:
|
| 303 |
+
logger.warning(
|
| 304 |
+
"LLM prune failed for %s, falling back to deterministic",
|
| 305 |
+
tool_name,
|
| 306 |
+
)
|
| 307 |
+
# fall through to deterministic pruning below
|
| 308 |
+
|
| 309 |
+
# --- Deterministic pruning ---
|
| 310 |
preview = content[:80]
|
| 311 |
total = len(content)
|
| 312 |
|
| 313 |
+
if tool_name == "bash":
|
|
|
|
|
|
|
|
|
|
| 314 |
exit_code_part = ""
|
| 315 |
if "exit_code" in content[:200]:
|
| 316 |
for line in content[:200].splitlines():
|
| 317 |
if "exit_code" in line:
|
| 318 |
+
exit_code_part = "exit_code visible if present, "
|
| 319 |
break
|
| 320 |
summary = f"[bash: {exit_code_part}{preview}... ({total} chars)]"
|
| 321 |
else:
|
|
|
|
| 327 |
self, model_name: str, tool_specs: list[dict] | None = None
|
| 328 |
) -> None:
|
| 329 |
"""Remove old messages to keep history under target size"""
|
| 330 |
+
await self.prune_old_tool_outputs(model_name=model_name)
|
| 331 |
|
| 332 |
if (self.context_length <= self.max_context) or not self.items:
|
| 333 |
return
|
agent/core/agent_loop.py
CHANGED
|
@@ -199,7 +199,9 @@ def _is_transient_error(error: Exception) -> bool:
|
|
| 199 |
|
| 200 |
async def _compact_and_notify(session: Session) -> None:
|
| 201 |
"""Run compaction and send event if context was reduced."""
|
| 202 |
-
session.context_manager.prune_old_tool_outputs(
|
|
|
|
|
|
|
| 203 |
old_length = session.context_manager.context_length
|
| 204 |
tool_specs = session.tool_router.get_tool_specs_for_llm()
|
| 205 |
await session.context_manager.compact(
|
|
|
|
| 199 |
|
| 200 |
async def _compact_and_notify(session: Session) -> None:
|
| 201 |
"""Run compaction and send event if context was reduced."""
|
| 202 |
+
await session.context_manager.prune_old_tool_outputs(
|
| 203 |
+
model_name=session.config.model_name,
|
| 204 |
+
)
|
| 205 |
old_length = session.context_manager.context_length
|
| 206 |
tool_specs = session.tool_router.get_tool_specs_for_llm()
|
| 207 |
await session.context_manager.compact(
|