agent-cost-optimizer / aco /cache_layout.py
narcolepticchicken's picture
Upload aco/cache_layout.py with huggingface_hub
eaafe86 verified
"""Cache-Aware Prompt Layout: Optimize prompt structure for prefix-cache reuse."""
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@dataclass
class PromptLayout:
    """A prompt split into a cacheable prefix and a dynamic suffix.

    Instances are produced by ``CacheAwareLayout.layout``; the comments on
    each field describe how that method fills it in.
    """

    # Prefix sections joined with blank lines; each section is rendered as
    # "# <source_name>\n<content>".
    prefix_content: str
    # Suffix sections, rendered and joined the same way as the prefix.
    suffix_content: str
    # Estimated token count of the prefix (sum of len(content) // 4 over the
    # prefix sources; section headers and separators are not counted).
    prefix_tokens: int
    # Estimated token count of the suffix, computed the same way.
    suffix_tokens: int
    cache_boundary: int  # token position of cache boundary (set equal to prefix_tokens)
    # Names of the sources placed in the prefix, in the order they were seen.
    stable_sources: List[str]
    # Names of the sources placed in the suffix, in the order they were seen.
    dynamic_sources: List[str]
# Named groups of prompt sources: "stable" names vs. "dynamic" names.
# NOTE(review): neither set is referenced by the code visible in this file --
# CacheAwareLayout.layout classifies sources through its ``context_budget``
# argument instead. Presumably these are defaults for callers that build
# that budget; confirm before relying on them.
CACHE_STABLE_SOURCES = {"system_rules", "tool_descriptions", "user_preferences"}
CACHE_DYNAMIC_SOURCES = {"recent_messages", "task_plan", "retrieved_docs", "artifacts"}
class CacheAwareLayout:
    """Arrange prompt sources into a cacheable prefix and a dynamic suffix.

    Each call to :meth:`layout` also records whether the rendered prefix is
    identical to the one produced by the previous call on this instance, so
    :meth:`stats` can report an observed prefix-reuse rate.
    """

    def __init__(self, session_id: Optional[str] = None):
        """Create a layout tracker with zeroed counters.

        Args:
            session_id: Optional identifier stored on the instance; it is not
                otherwise used by this class.
        """
        self.session_id = session_id
        # Digest of the most recent prefix; None until the first layout() call.
        self._prev_prefix_hash = None
        self.cache_hits = 0
        self.cache_misses = 0
        self.total_prefix_tokens = 0
        self.total_suffix_tokens = 0

    @staticmethod
    def _goes_in_prefix(source_name: str, context_budget) -> bool:
        """Return True when *source_name* belongs in the cacheable prefix."""
        if source_name in context_budget.cache_prefix:
            return True
        # dynamic_suffix is consulted before keep_exact, so a name listed in
        # both ends up in the suffix.
        if source_name in context_budget.dynamic_suffix:
            return False
        return source_name in context_budget.keep_exact

    def layout(self, sources: Dict[str, str], context_budget) -> "PromptLayout":
        """Split *sources* into prefix and suffix text and update hit counters.

        A source goes into the prefix when its name is in
        ``context_budget.cache_prefix``, or in ``context_budget.keep_exact``
        without also being in ``context_budget.dynamic_suffix``. Every other
        source, including names the budget does not mention, goes into the
        suffix. Each source is rendered as ``"# <name>\\n<content>"`` and the
        sections are joined with blank lines.

        Args:
            sources: Mapping of source name to its text. Sections are emitted
                in this mapping's iteration order, so the prefix only repeats
                between calls when the caller keeps that order stable.
            context_budget: Object exposing ``cache_prefix``,
                ``dynamic_suffix`` and ``keep_exact``, each supporting ``in``
                tests on source names.

        Returns:
            A ``PromptLayout`` holding both texts, their estimated token
            counts and the source names assigned to each side.
        """
        prefix_parts: List[str] = []
        suffix_parts: List[str] = []
        stable: List[str] = []
        dynamic: List[str] = []
        prefix_tokens = 0
        suffix_tokens = 0

        for source_name, content in sources.items():
            # Rough estimate (~4 characters per token); the "# name" header
            # and the separators are not counted.
            token_est = len(content) // 4
            section = f"# {source_name}\n{content}"
            if self._goes_in_prefix(source_name, context_budget):
                prefix_parts.append(section)
                prefix_tokens += token_est
                stable.append(source_name)
            else:
                suffix_parts.append(section)
                suffix_tokens += token_est
                dynamic.append(source_name)

        prefix = "\n\n".join(prefix_parts)
        suffix = "\n\n".join(suffix_parts)

        # Fingerprint the prefix to detect reuse. This is not a security use;
        # sha256 is chosen because md5 can be unavailable on restricted
        # (FIPS) builds. "surrogatepass" keeps lone surrogates in the content
        # from raising during what is only bookkeeping.
        prefix_hash = hashlib.sha256(
            prefix.encode("utf-8", "surrogatepass")
        ).hexdigest()
        if self._prev_prefix_hash == prefix_hash:
            self.cache_hits += 1
        else:
            # The first call always lands here: there is no previous digest.
            self.cache_misses += 1
        self._prev_prefix_hash = prefix_hash

        self.total_prefix_tokens += prefix_tokens
        self.total_suffix_tokens += suffix_tokens

        return PromptLayout(
            prefix_content=prefix,
            suffix_content=suffix,
            prefix_tokens=prefix_tokens,
            suffix_tokens=suffix_tokens,
            cache_boundary=prefix_tokens,
            stable_sources=stable,
            dynamic_sources=dynamic,
        )

    def stats(self) -> Dict:
        """Return hit/miss counters and per-call token averages.

        Averages and the hit rate are divided by the number of layout() calls
        recorded so far; before any call the divisor is clamped to 1, so all
        values are zero rather than raising ZeroDivisionError.
        """
        total = self.cache_hits + self.cache_misses
        calls = max(total, 1)
        return {
            "cache_hit_rate": self.cache_hits / calls,
            "total_cache_hits": self.cache_hits,
            "total_cache_misses": self.cache_misses,
            "avg_prefix_tokens": self.total_prefix_tokens / calls,
            "avg_suffix_tokens": self.total_suffix_tokens / calls,
        }