"""Cache-Aware Prompt Layout: Optimize prompt structure for prefix-cache reuse."""

import hashlib
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class PromptLayout:
    """Result of splitting prompt sources into a cacheable prefix and a suffix."""

    prefix_content: str  # "\n\n"-joined "# <name>\n<content>" sections routed to the prefix
    suffix_content: str  # "\n\n"-joined "# <name>\n<content>" sections routed to the suffix
    prefix_tokens: int  # estimated tokens of prefix source contents (len // 4 per source)
    suffix_tokens: int  # estimated tokens of suffix source contents (len // 4 per source)
    cache_boundary: int  # token position of cache boundary (set equal to prefix_tokens)
    stable_sources: List[str]  # source names placed in the prefix, in input order
    dynamic_sources: List[str]  # source names placed in the suffix, in input order


# NOTE(review): neither set is referenced by CacheAwareLayout; routing is driven
# entirely by the `context_budget` passed to layout(). Presumably callers use
# these to build that budget -- confirm before relying on them.
CACHE_STABLE_SOURCES = {"system_rules", "tool_descriptions", "user_preferences"}
CACHE_DYNAMIC_SOURCES = {"recent_messages", "task_plan", "retrieved_docs", "artifacts"}


class CacheAwareLayout:
    """Split prompt sources into a cacheable prefix and a per-call suffix.

    Across successive ``layout()`` calls it records whether the rendered
    prefix was byte-identical to the previous call's prefix (a "cache hit")
    and accumulates estimated token totals, reported by ``stats()``.
    """

    def __init__(self, session_id: Optional[str] = None):
        self.session_id = session_id
        # Digest of the prefix rendered by the previous layout() call;
        # None until layout() has run once.
        self._prev_prefix_hash: Optional[str] = None
        self.cache_hits = 0
        self.cache_misses = 0
        self.total_prefix_tokens = 0
        self.total_suffix_tokens = 0

    def layout(self, sources: Dict[str, str], context_budget) -> PromptLayout:
        """Route each source to the prefix or the suffix and render both.

        Args:
            sources: Mapping of source name to text content. Iteration order
                is preserved within the prefix and within the suffix.
            context_budget: Object exposing ``cache_prefix``,
                ``dynamic_suffix`` and ``keep_exact`` attributes that support
                ``in`` membership tests on source names (presumably sets of
                names -- confirm against the budget class).

        Returns:
            A PromptLayout describing the rendered prefix and suffix.

        Routing precedence: ``cache_prefix`` wins over ``dynamic_suffix``;
        a name in neither but in ``keep_exact`` goes to the prefix; any other
        name goes to the suffix.
        """
        prefix_parts: List[str] = []
        suffix_parts: List[str] = []
        stable: List[str] = []
        dynamic: List[str] = []
        prefix_tokens = 0
        suffix_tokens = 0

        for source_name, content in sources.items():
            section = f"# {source_name}\n{content}"
            # Rough ~4 chars/token heuristic; the "# name" header and the
            # "\n\n" separators are not counted.
            token_est = len(content) // 4

            # Membership is checked in precedence order and short-circuits.
            if source_name in context_budget.cache_prefix:
                in_prefix = True
            elif source_name in context_budget.dynamic_suffix:
                in_prefix = False
            else:
                in_prefix = source_name in context_budget.keep_exact

            if in_prefix:
                prefix_parts.append(section)
                prefix_tokens += token_est
                stable.append(source_name)
            else:
                suffix_parts.append(section)
                suffix_tokens += token_est
                dynamic.append(source_name)

        prefix = "\n\n".join(prefix_parts)
        suffix = "\n\n".join(suffix_parts)

        # A hit means the prefix is byte-identical to the previous call's.
        # sha256 rather than md5: the digest is only compared for equality,
        # and md5 can be disabled on FIPS-mode OpenSSL builds.
        # NOTE(review): an empty prefix (no stable sources) hashes the same
        # every call and is therefore counted as a hit after the first call;
        # confirm that accounting is intended.
        prefix_hash = hashlib.sha256(prefix.encode()).hexdigest()
        if self._prev_prefix_hash == prefix_hash:
            self.cache_hits += 1
        else:
            self.cache_misses += 1
        self._prev_prefix_hash = prefix_hash

        self.total_prefix_tokens += prefix_tokens
        self.total_suffix_tokens += suffix_tokens

        return PromptLayout(
            prefix_content=prefix,
            suffix_content=suffix,
            prefix_tokens=prefix_tokens,
            suffix_tokens=suffix_tokens,
            cache_boundary=prefix_tokens,
            stable_sources=stable,
            dynamic_sources=dynamic,
        )

    def stats(self) -> Dict:
        """Return cumulative cache and token statistics over all layout() calls.

        Keys: ``cache_hit_rate``, ``total_cache_hits``, ``total_cache_misses``,
        ``avg_prefix_tokens``, ``avg_suffix_tokens``. Rates and averages are
        0.0 before any layout() call (the divisor is clamped to at least 1).
        """
        total = self.cache_hits + self.cache_misses
        denom = max(total, 1)
        return {
            "cache_hit_rate": self.cache_hits / denom,
            "total_cache_hits": self.cache_hits,
            "total_cache_misses": self.cache_misses,
            "avg_prefix_tokens": self.total_prefix_tokens / denom,
            "avg_suffix_tokens": self.total_suffix_tokens / denom,
        }