| """Cache-Aware Prompt Layout - Module 5. |
| |
| Optimizes prompt/context structure for prefix-cache reuse. |
| |
| Strategy: |
| - Keep stable rules in the prefix (system rules, tool descriptions, user preferences) |
| - Keep tool descriptions stable |
| - Move dynamic content to the suffix (user message, retrieved docs, recent trace, artifacts) |
| - Avoid injecting timestamps/random metadata above cache boundary |
| - Preserve sticky provider/session routing where useful |
| |
| Metrics: |
| - cache hit rate |
| - warm-cache cost |
| - cold-cache cost |
| - latency |
| - context staleness failures |
| """ |
|
|
| from typing import Dict, List, Tuple, Optional, Any |
| from dataclasses import dataclass |
|
|
| from .config import ACOConfig |
|
|
|
|
@dataclass
class PromptLayout:
    """Result of splitting prompt content around a prefix-cache boundary."""

    # Stable, cacheable content placed above the cache boundary.
    prefix: str
    # Per-request dynamic content placed after the boundary.
    suffix: str
    # Approximate token counts (chars // 4 estimate, not tokenizer-exact).
    prefix_tokens: int
    suffix_tokens: int
    # Token offset where the cacheable prefix ends (equals prefix_tokens).
    cache_boundary_token: int
    # Estimated input cost when nothing is cached (full price for all tokens).
    estimated_cold_cost: float
    # Estimated input cost when the prefix is served from cache.
    estimated_warm_cost: float
    # Estimated savings on a warm run (cold cost minus warm cost).
    cache_discount: float
|
|
|
|
class CacheAwarePromptLayout:
    """Lays out prompts to maximize prefix cache reuse.

    Content that is stable across requests (system rules, tool descriptions,
    user preferences, ...) is concatenated into a prefix above the cache
    boundary so providers can reuse cached prefill state; per-request content
    (user message, retrieved docs, timestamps, ...) goes into the suffix so
    the prefix stays byte-identical between calls.
    """

    # Content keys known to be stable across requests (safe to cache).
    PREFIX_CONTENT_TYPES = [
        "system_rules",
        "tool_descriptions",
        "user_preferences",
        "static_examples",
        "persona_definition",
    ]

    # Content keys known to vary per request (must stay below the boundary).
    SUFFIX_CONTENT_TYPES = [
        "user_message",
        "retrieved_docs",
        "recent_trace",
        "artifacts",
        "timestamp",
        "session_id",
        "dynamic_examples",
        "conversation_history",
    ]

    # Substring heuristics applied to keys not listed above.
    _PREFIX_KEYWORDS = ("system", "static", "rule", "persona", "schema", "format")
    _SUFFIX_KEYWORDS = ("user_", "dynamic", "current", "live", "now", "timestamp")

    def __init__(self, config: Optional["ACOConfig"] = None):
        """Initialize with an optional ACOConfig; a default one is built otherwise."""
        self.config = config or ACOConfig()
        # Aggregate statistics, updated by layout() and measure_hit_rate().
        self.cache_stats: Dict[str, Any] = {
            "cold_runs": 0,
            "warm_runs": 0,
            "prefix_tokens_avg": 0,
            "cache_hit_rate": 0.0,
            "staleness_failures": 0,
            # Running average of the per-layout cache discount.  Previously
            # this key was read by report() but never written anywhere, so
            # estimated_cost_saved was always 0.0.
            "avg_cache_discount": 0.0,
        }
        # Number of layout() calls; denominator for avg_cache_discount.
        self._layouts_built = 0

    def layout(
        self,
        content_pieces: Dict[str, str],
        cost_per_1k_input: float = 0.01,
        cache_discount_rate: float = 0.5,
    ) -> "PromptLayout":
        """Partition content into prefix (cacheable) and suffix (dynamic).

        Args:
            content_pieces: Mapping of content-type key -> text.
            cost_per_1k_input: Provider price per 1000 input tokens.
            cache_discount_rate: Fraction of the prefix price waived on a
                cache hit (0.5 means cached tokens cost half).

        Returns:
            A PromptLayout with the joined prefix/suffix, estimated token
            counts, and cold/warm cost estimates.
        """
        prefix_pieces: List[str] = []
        suffix_pieces: List[str] = []
        for key, text in content_pieces.items():
            target = prefix_pieces if self._is_prefix_content(key) else suffix_pieces
            target.append(text)

        prefix = "\n\n".join(prefix_pieces)
        suffix = "\n\n".join(suffix_pieces)

        # Rough token estimate: ~4 characters per token.
        prefix_tokens = len(prefix) // 4
        suffix_tokens = len(suffix) // 4

        # Cold run: every token billed at full price.  Warm run: suffix at
        # full price, cached prefix at the discounted rate.
        estimated_cold_cost = (
            (prefix_tokens + suffix_tokens) / 1000
        ) * cost_per_1k_input
        estimated_warm_cost = (
            (suffix_tokens + prefix_tokens * (1 - cache_discount_rate)) / 1000
        ) * cost_per_1k_input
        cache_discount = (
            (prefix_tokens * cache_discount_rate) / 1000
        ) * cost_per_1k_input

        # Maintain a running average of the discount so report() can estimate
        # total cost saved across warm runs (fixes "avg_cache_discount" being
        # read but never written).
        self._layouts_built += 1
        prev = self.cache_stats.get("avg_cache_discount", 0.0)
        self.cache_stats["avg_cache_discount"] = (
            prev + (cache_discount - prev) / self._layouts_built
        )

        return PromptLayout(
            prefix=prefix,
            suffix=suffix,
            prefix_tokens=prefix_tokens,
            suffix_tokens=suffix_tokens,
            cache_boundary_token=prefix_tokens,
            estimated_cold_cost=estimated_cold_cost,
            estimated_warm_cost=estimated_warm_cost,
            cache_discount=cache_discount,
        )

    def _is_prefix_content(self, key: str) -> bool:
        """Return True if content under *key* belongs in the cacheable prefix."""
        # Explicit lists take precedence over keyword heuristics.
        if key in self.PREFIX_CONTENT_TYPES:
            return True
        if key in self.SUFFIX_CONTENT_TYPES:
            return False

        # Fall back to substring heuristics for unlisted keys.
        lowered = key.lower()
        if any(kw in lowered for kw in self._PREFIX_KEYWORDS):
            return True
        if any(kw in lowered for kw in self._SUFFIX_KEYWORDS):
            return False

        # Unknown content defaults to the prefix.  NOTE(review): this assumes
        # unlisted content is stable; a dynamic unknown key placed above the
        # boundary would defeat cache reuse on every request — confirm this
        # default is intended.
        return True

    def optimize_for_provider(
        self,
        layout: "PromptLayout",
        provider: str,
    ) -> "PromptLayout":
        """Apply provider-specific cache layout optimizations.

        Currently a pass-through for every provider; the branches are hooks
        for future provider-specific boundary handling.
        """
        provider = provider.lower()

        if "anthropic" in provider:
            # Placeholder: no Anthropic-specific adjustment implemented yet.
            return layout
        if "openai" in provider:
            # Placeholder: no OpenAI-specific adjustment implemented yet.
            return layout
        if "gemini" in provider:
            # Placeholder: no Gemini-specific adjustment implemented yet.
            return layout
        if "deepseek" in provider:
            # Placeholder: no DeepSeek-specific adjustment implemented yet.
            return layout
        return layout

    def measure_hit_rate(self, prefix_tokens: int, cache_hit: bool) -> None:
        """Record one run's cache outcome and update aggregate statistics.

        Args:
            prefix_tokens: Prefix size of the run, folded into a running mean.
            cache_hit: True if the provider served the prefix from cache.
        """
        if cache_hit:
            self.cache_stats["warm_runs"] += 1
        else:
            self.cache_stats["cold_runs"] += 1

        total = self.cache_stats["warm_runs"] + self.cache_stats["cold_runs"]
        self.cache_stats["cache_hit_rate"] = (
            self.cache_stats["warm_runs"] / total if total > 0 else 0.0
        )

        # Incremental running mean; total >= 1 here because one counter was
        # just incremented, so the division is safe.
        avg = self.cache_stats["prefix_tokens_avg"]
        self.cache_stats["prefix_tokens_avg"] = avg + (prefix_tokens - avg) / total

    def report(self) -> Dict[str, Any]:
        """Generate a cache performance report snapshot.

        "estimated_cost_saved" multiplies warm runs by the average per-layout
        cache discount tracked in layout().
        """
        total = self.cache_stats["warm_runs"] + self.cache_stats["cold_runs"]
        return {
            "total_runs": total,
            "warm_runs": self.cache_stats["warm_runs"],
            "cold_runs": self.cache_stats["cold_runs"],
            "cache_hit_rate": self.cache_stats["cache_hit_rate"],
            "avg_prefix_tokens": self.cache_stats["prefix_tokens_avg"],
            "staleness_failures": self.cache_stats["staleness_failures"],
            "estimated_cost_saved": (
                self.cache_stats["warm_runs"]
                * self.cache_stats.get("avg_cache_discount", 0.0)
            ),
        }
|
|