Committed by XciD7
Commit 5d357ba · unverified · 1 Parent(s): e2552e8

feat: enable Anthropic prompt caching on system prompt and tools (#69)


* feat: enable Anthropic prompt caching on system prompt and tools

Mark the rendered system prompt and the tool block with cache_control
breakpoints when calling Anthropic models. The static prefix (~4-5K
tokens of system prompt + 15+ tool definitions) was being re-billed at
full input rate on every turn, every retry, and every research
sub-agent iteration (up to 60 per task).

With ephemeral cache breakpoints, subsequent turns within the 5-minute
TTL are billed at cache-read pricing (~10% of the base input rate).
Expected savings: 40-50% of input-token cost on multi-turn conversations,
60-80% on research sub-agent loops (see the cost sketch after this
commit message).

Caching is GA in the Anthropic API and natively supported by litellm
1.83+ via cache_control blocks (no beta header required). Non-Anthropic
models (HF router, OpenAI) are passed through unchanged.

The helper does not mutate the caller's message list or tool list, so
the persisted ContextManager.items history stays in its original
string-content form.

* refactor: hoist prompt_caching imports to module level, drop cached_ prefix
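
As a rough check on the savings claim above, here is a back-of-the-envelope cost
sketch. The base price and token counts are assumptions for illustration; the 1.25x
cache-write and 0.1x cache-read multipliers are the ones Anthropic documents for the
5-minute ephemeral cache.

    # Illustrative arithmetic only: assumed price and token counts, not measurements.
    BASE_INPUT_PRICE = 3.00 / 1_000_000  # $/input token, assumed Sonnet-class pricing
    CACHE_WRITE_MULT = 1.25              # cache writes cost 25% more than base input
    CACHE_READ_MULT = 0.10               # cache reads cost ~10% of base input

    static_prefix = 4_500  # ~4-5K tokens of system prompt + tool definitions
    turns = 10             # assumed turns landing within the 5-minute TTL

    without_cache = turns * static_prefix * BASE_INPUT_PRICE
    with_cache = (
        static_prefix * BASE_INPUT_PRICE * CACHE_WRITE_MULT                 # first turn writes the cache
        + (turns - 1) * static_prefix * BASE_INPUT_PRICE * CACHE_READ_MULT  # later turns read it
    )

    print(f"static prefix, no caching:    ${without_cache:.4f}")
    print(f"static prefix, with caching:  ${with_cache:.4f}")
    print(f"savings on the static prefix: {1 - with_cache / without_cache:.0%}")

The overall per-conversation savings land below this prefix-only figure because the
growing message history is still billed at the full input rate, which is consistent
with the 40-50% / 60-80% estimates above.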

agent/context_manager/manager.py CHANGED
@@ -13,6 +13,8 @@ import yaml
 from jinja2 import Template
 from litellm import Message, acompletion
 
+from agent.core.prompt_caching import with_prompt_caching
+
 logger = logging.getLogger(__name__)
 
 _HF_WHOAMI_URL = "https://huggingface.co/api/whoami-v2"
@@ -114,6 +116,9 @@ async def summarize_messages(
 
     prompt_messages = list(messages) + [Message(role="user", content=prompt)]
     llm_params = _resolve_llm_params(model_name, hf_token, reasoning_effort="high")
+    prompt_messages, tool_specs = with_prompt_caching(
+        prompt_messages, tool_specs, llm_params.get("model")
+    )
     response = await acompletion(
         messages=prompt_messages,
         max_completion_tokens=max_tokens,
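
For concreteness, an illustrative sketch (not part of the diff) of the shapes that
with_prompt_caching hands to acompletion for an Anthropic model; the system prompt,
tool definition, and model name are placeholders:

    # Hypothetical inputs.
    messages = [
        {"role": "system", "content": "You are the agent... (rendered system prompt)"},
        {"role": "user", "content": "Summarize the last run."},
    ]
    tool_specs = [
        {"type": "function", "function": {"name": "run_job", "parameters": {"type": "object"}}},
    ]

    # After with_prompt_caching(messages, tool_specs, "anthropic/claude-sonnet-4"):
    # the system message becomes a content-block list carrying a cache_control
    # breakpoint, and the last tool definition gets a top-level cache_control key.
    cached_messages = [
        {
            "role": "system",
            "content": [{
                "type": "text",
                "text": "You are the agent... (rendered system prompt)",
                "cache_control": {"type": "ephemeral"},
            }],
        },
        {"role": "user", "content": "Summarize the last run."},
    ]
    cached_tools = [
        {
            "type": "function",
            "function": {"name": "run_job", "parameters": {"type": "object"}},
            "cache_control": {"type": "ephemeral"},
        },
    ]

litellm translates these cache_control markers into Anthropic's native request
format; other providers never see them because the helper returns its inputs
unchanged for non-Anthropic models.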
agent/core/agent_loop.py CHANGED
@@ -14,6 +14,7 @@ from litellm.exceptions import ContextWindowExceededError
 from agent.config import Config
 from agent.core.doom_loop import check_for_doom_loop
 from agent.core.llm_params import _resolve_llm_params
+from agent.core.prompt_caching import with_prompt_caching
 from agent.core.session import Event, OpType, Session
 from agent.core.tools import ToolRouter
 from agent.tools.jobs_tool import CPU_FLAVORS
@@ -296,6 +297,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
     """Call the LLM with streaming, emitting assistant_chunk events."""
     response = None
     _healed_effort = False  # one-shot safety net per call
+    messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
     for _llm_attempt in range(_MAX_LLM_RETRIES):
         try:
             response = await acompletion(
@@ -390,6 +392,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
     """Call the LLM without streaming, emit assistant_message at the end."""
     response = None
     _healed_effort = False
+    messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
     for _llm_attempt in range(_MAX_LLM_RETRIES):
         try:
             response = await acompletion(
agent/core/prompt_caching.py ADDED
@@ -0,0 +1,59 @@
+"""Anthropic prompt caching breakpoints for outgoing LLM requests.
+
+Caching is GA on Anthropic's API and natively supported by litellm >=1.83
+via ``cache_control`` blocks. We apply two breakpoints (out of 4 allowed):
+
+1. The tool block — caches all tool definitions as a single prefix.
+2. The system message — caches the rendered system prompt.
+
+Together these cover the ~4-5K static tokens that were being re-billed on
+every turn. Subsequent turns within the 5-minute TTL hit cache_read pricing
+(~10% of input cost) instead of full input.
+
+Non-Anthropic models (HF router, OpenAI) are passed through unchanged.
+"""
+
+from typing import Any
+
+
+def with_prompt_caching(
+    messages: list[Any],
+    tools: list[dict] | None,
+    model_name: str | None,
+) -> tuple[list[Any], list[dict] | None]:
+    """Return (messages, tools) with cache_control breakpoints for Anthropic.
+
+    No-op for non-Anthropic models. Original objects are not mutated; a fresh
+    list with replaced first message and last tool is returned, so callers
+    that share the underlying ``ContextManager.items`` list don't see their
+    persisted history rewritten.
+    """
+    if not model_name or not model_name.startswith("anthropic/"):
+        return messages, tools
+
+    if tools:
+        new_tools = list(tools)
+        last = dict(new_tools[-1])
+        last["cache_control"] = {"type": "ephemeral"}
+        new_tools[-1] = last
+        tools = new_tools
+
+    if messages:
+        first = messages[0]
+        role = first.get("role") if isinstance(first, dict) else getattr(first, "role", None)
+        if role == "system":
+            content = (
+                first.get("content")
+                if isinstance(first, dict)
+                else getattr(first, "content", None)
+            )
+            if isinstance(content, str) and content:
+                cached_block = [{
+                    "type": "text",
+                    "text": content,
+                    "cache_control": {"type": "ephemeral"},
+                }]
+                new_first = {"role": "system", "content": cached_block}
+                messages = [new_first] + list(messages[1:])
+
+    return messages, tools
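
A quick behavioral sketch of the helper above, exercising the no-op path and the
non-mutation guarantee from its docstring (the model strings are placeholders):

    from agent.core.prompt_caching import with_prompt_caching

    system = {"role": "system", "content": "rendered system prompt"}
    messages = [system, {"role": "user", "content": "hi"}]
    tools = [{"type": "function", "function": {"name": "bash", "parameters": {}}}]

    # Non-Anthropic model: inputs are passed through untouched.
    m, t = with_prompt_caching(messages, tools, "openai/gpt-4o")
    assert m is messages and t is tools

    # Anthropic model: fresh lists come back and the originals stay as plain
    # string-content history, which is what ContextManager.items persists.
    m, t = with_prompt_caching(messages, tools, "anthropic/claude-sonnet-4")
    assert messages[0]["content"] == "rendered system prompt"
    assert "cache_control" not in tools[-1]
    assert m[0]["content"][0]["cache_control"] == {"type": "ephemeral"}
    assert t[-1]["cache_control"] == {"type": "ephemeral"}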
agent/tools/research_tool.py CHANGED
@@ -15,6 +15,7 @@ from litellm import Message, acompletion
 
 from agent.core.doom_loop import check_for_doom_loop
 from agent.core.llm_params import _resolve_llm_params
+from agent.core.prompt_caching import with_prompt_caching
 from agent.core.session import Event
 
 logger = logging.getLogger(__name__)
@@ -323,8 +324,9 @@ async def research_handler(
             ),
         ))
         try:
+            _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
             response = await acompletion(
-                messages=messages,
+                messages=_msgs,
                 tools=None,  # no tools — force text response
                 stream=False,
                 timeout=120,
@@ -348,9 +350,12 @@ async def research_handler(
         ))
 
         try:
+            _msgs, _tools = with_prompt_caching(
+                messages, tool_specs if tool_specs else None, llm_params.get("model")
+            )
             response = await acompletion(
-                messages=messages,
-                tools=tool_specs if tool_specs else None,
+                messages=_msgs,
+                tools=_tools,
                 tool_choice="auto",
                 stream=False,
                 timeout=120,
@@ -446,8 +451,9 @@ async def research_handler(
             ),
         ))
         try:
+            _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
             response = await acompletion(
-                messages=messages,
+                messages=_msgs,
                 tools=None,
                 stream=False,
                 timeout=120,
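
To confirm the breakpoints are actually being hit in production, the cache counters
in the response usage can be logged. The Anthropic API reports
cache_creation_input_tokens and cache_read_input_tokens in its usage block; whether
and where litellm surfaces them on response.usage varies by version, so this sketch
treats the field names as an assumption and reads them defensively:

    import logging

    logger = logging.getLogger(__name__)

    def log_cache_usage(response) -> None:
        """Log prompt-cache counters if the provider/litellm surfaced them."""
        usage = getattr(response, "usage", None)
        created = getattr(usage, "cache_creation_input_tokens", None) if usage else None
        read = getattr(usage, "cache_read_input_tokens", None) if usage else None
        if created or read:
            logger.info("prompt cache: wrote %s, read %s tokens", created or 0, read or 0)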