"""Cache-Aware Prompt Layout - Module 5.

Optimizes prompt/context structure for prefix-cache reuse.

Strategy:
- Keep stable rules in the prefix (system rules, tool descriptions,
  user preferences)
- Keep tool descriptions stable
- Move dynamic content to the suffix (user message, retrieved docs,
  recent trace, artifacts)
- Avoid injecting timestamps/random metadata above the cache boundary
- Preserve sticky provider/session routing where useful

Metrics:
- cache hit rate
- warm-cache cost
- cold-cache cost
- latency
- context staleness failures
"""

from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass

from .config import ACOConfig


@dataclass
class PromptLayout:
    """Result of partitioning prompt content around the cache boundary."""

    prefix: str  # Stable, cacheable content
    suffix: str  # Dynamic content per turn
    prefix_tokens: int  # Estimated token count of ``prefix``
    suffix_tokens: int  # Estimated token count of ``suffix``
    cache_boundary_token: int  # Token offset where the cacheable prefix ends
    estimated_cold_cost: float  # Input cost when nothing is cached
    estimated_warm_cost: float  # Input cost when the prefix is served from cache
    cache_discount: float  # estimated_cold_cost - estimated_warm_cost


class CacheAwarePromptLayout:
    """Lays out prompts to maximize prefix cache reuse.

    Content pieces are split into a stable, cacheable prefix and a dynamic
    per-turn suffix. Cache hit/miss outcomes are tracked in ``cache_stats``
    and summarized by ``report``.
    """

    # Content types that should stay in prefix, listed most-stable first.
    # This order also determines the order of prefix pieces in ``layout``.
    PREFIX_CONTENT_TYPES = [
        "system_rules",
        "tool_descriptions",
        "user_preferences",
        "static_examples",
        "persona_definition",
    ]

    # Content types that should be in suffix
    SUFFIX_CONTENT_TYPES = [
        "user_message",
        "retrieved_docs",
        "recent_trace",
        "artifacts",
        "timestamp",
        "session_id",
        "dynamic_examples",
        "conversation_history",
    ]

    # Substring heuristics for keys not listed above. Matching is by plain
    # substring, so e.g. "now" also matches "knowledge" or "unknown".
    _STABLE_KEYWORDS = ("system", "static", "rule", "persona", "schema", "format")
    _DYNAMIC_KEYWORDS = ("user_", "dynamic", "current", "live", "now", "timestamp")

    # Placement of keys that match no list and no keyword. True (prefix)
    # preserves the historical behavior; a subclass may set False to keep
    # unrecognized, possibly dynamic content below the cache boundary.
    UNKNOWN_CONTENT_IN_PREFIX = True

    def __init__(self, config: Optional[ACOConfig] = None):
        # Explicit None check so a falsy-but-valid config object is not discarded.
        self.config = config if config is not None else ACOConfig()
        self.cache_stats: Dict[str, Any] = {
            "cold_runs": 0,
            "warm_runs": 0,
            "prefix_tokens_avg": 0,
            "cache_hit_rate": 0.0,
            "staleness_failures": 0,
            # Sum of cache discounts passed to measure_hit_rate on warm runs.
            "cost_saved": 0.0,
        }

    def layout(
        self,
        content_pieces: Dict[str, str],
        cost_per_1k_input: float = 0.01,
        cache_discount_rate: float = 0.5,
    ) -> PromptLayout:
        """Partition content into prefix (cacheable) and suffix (dynamic).

        Args:
            content_pieces: Mapping of content-type key to text. Keys are
                classified by ``_is_prefix_content``. Empty texts are skipped.
            cost_per_1k_input: Price per 1000 input tokens.
            cache_discount_rate: Fraction of the prefix price saved when the
                prefix is served from cache; must lie in [0, 1].

        Returns:
            A ``PromptLayout`` with the joined prefix/suffix, rough token
            estimates and cold/warm cost estimates.

        Raises:
            ValueError: If ``cache_discount_rate`` is outside [0, 1].
        """
        if not 0.0 <= cache_discount_rate <= 1.0:
            raise ValueError(
                f"cache_discount_rate must be in [0, 1], got {cache_discount_rate!r}"
            )

        prefix_items: List[Tuple[str, str]] = []
        suffix_pieces: List[str] = []
        for key, text in content_pieces.items():
            if not text:
                # An empty piece would add a stray "\n\n" separator, making the
                # prefix bytes depend on whether an optional piece was sent.
                continue
            if self._is_prefix_content(key):
                prefix_items.append((key, text))
            else:
                suffix_pieces.append(text)

        # Sort prefix most-stable first so identical content always produces
        # an identical prefix, independent of the caller's dict order.
        prefix_items.sort(key=lambda item: self._prefix_rank(item[0]))

        prefix = "\n\n".join(text for _, text in prefix_items)
        suffix = "\n\n".join(suffix_pieces)

        # Token estimation (rough: 1 token ~ 4 chars for English)
        prefix_tokens = len(prefix) // 4
        suffix_tokens = len(suffix) // 4

        # Costs: a warm run pays only (1 - rate) of the prefix price.
        estimated_cold_cost = ((prefix_tokens + suffix_tokens) / 1000) * cost_per_1k_input
        estimated_warm_cost = (
            (suffix_tokens + prefix_tokens * (1 - cache_discount_rate)) / 1000
        ) * cost_per_1k_input
        cache_discount = ((prefix_tokens * cache_discount_rate) / 1000) * cost_per_1k_input

        return PromptLayout(
            prefix=prefix,
            suffix=suffix,
            prefix_tokens=prefix_tokens,
            suffix_tokens=suffix_tokens,
            cache_boundary_token=prefix_tokens,
            estimated_cold_cost=estimated_cold_cost,
            estimated_warm_cost=estimated_warm_cost,
            cache_discount=cache_discount,
        )

    def _prefix_rank(self, key: str) -> Tuple[int, str]:
        """Sort key for prefix pieces: known types in list order, then by name."""
        normalized = key.lower()
        if normalized in self.PREFIX_CONTENT_TYPES:
            return (self.PREFIX_CONTENT_TYPES.index(normalized), key)
        return (len(self.PREFIX_CONTENT_TYPES), key)

    def _is_prefix_content(self, key: str) -> bool:
        """Determine if a content key belongs in the prefix.

        Resolution order: explicit prefix list, explicit suffix list, dynamic
        keywords, stable keywords, then ``UNKNOWN_CONTENT_IN_PREFIX``. All
        comparisons are case-insensitive.
        """
        normalized = key.lower()

        # Direct matches
        if normalized in self.PREFIX_CONTENT_TYPES:
            return True
        if normalized in self.SUFFIX_CONTENT_TYPES:
            return False

        # Dynamic keywords win over stable ones: per-turn content above the
        # cache boundary invalidates the cache every turn, whereas stable
        # content below it only forgoes a discount.
        if any(kw in normalized for kw in self._DYNAMIC_KEYWORDS):
            return False
        if any(kw in normalized for kw in self._STABLE_KEYWORDS):
            return True

        # No list or keyword matched.
        return self.UNKNOWN_CONTENT_IN_PREFIX

    def optimize_for_provider(
        self,
        layout: PromptLayout,
        provider: str,
    ) -> PromptLayout:
        """Provider-specific cache layout optimizations.

        Currently a pass-through: the layout is returned unchanged for every
        provider. Notes carried over for future provider-specific work
        (NOTE(review): verify against each provider's current docs):
        - anthropic: keep system content separate from the dynamic suffix.
        - openai: prefix caching; keep system content at the very top.
        - gemini: context caching for repeated contexts.
        - deepseek: cache-hit discounts.
        """
        return layout

    def measure_hit_rate(
        self,
        prefix_tokens: int,
        cache_hit: bool,
        cache_discount: float = 0.0,
    ) -> None:
        """Update cache statistics with the outcome of one run.

        Args:
            prefix_tokens: Prefix size of the run, folded into a running average.
            cache_hit: Whether the prefix was served from cache (warm run).
            cache_discount: Savings realized by this run (e.g.
                ``PromptLayout.cache_discount``); only counted on warm runs.
        """
        stats = self.cache_stats
        if cache_hit:
            stats["warm_runs"] += 1
            stats["cost_saved"] = stats.get("cost_saved", 0.0) + cache_discount
        else:
            stats["cold_runs"] += 1

        total = stats["warm_runs"] + stats["cold_runs"]
        stats["cache_hit_rate"] = stats["warm_runs"] / total if total > 0 else 0.0

        # Running average of prefix tokens
        n = total
        stats["prefix_tokens_avg"] = (stats["prefix_tokens_avg"] * (n - 1) + prefix_tokens) / n

    def record_staleness_failure(self) -> None:
        """Count one context-staleness failure (surfaced by ``report``)."""
        self.cache_stats["staleness_failures"] += 1

    def report(self) -> Dict[str, Any]:
        """Generate cache performance report from ``cache_stats``."""
        stats = self.cache_stats
        total = stats["warm_runs"] + stats["cold_runs"]
        return {
            "total_runs": total,
            "warm_runs": stats["warm_runs"],
            "cold_runs": stats["cold_runs"],
            "cache_hit_rate": stats["cache_hit_rate"],
            "avg_prefix_tokens": stats["prefix_tokens_avg"],
            "staleness_failures": stats["staleness_failures"],
            "estimated_cost_saved": stats.get("cost_saved", 0.0),
        }