# aco/cache_layout.py
"""Cache-Aware Prompt Layout - Module 5.
Optimizes prompt/context structure for prefix-cache reuse.
Strategy:
- Keep stable rules in the prefix (system rules, tool descriptions, user preferences)
- Keep tool descriptions stable
- Move dynamic content to the suffix (user message, retrieved docs, recent trace, artifacts)
- Avoid injecting timestamps/random metadata above cache boundary
- Preserve sticky provider/session routing where useful
Metrics:
- cache hit rate
- warm-cache cost
- cold-cache cost
- latency
- context staleness failures
"""
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from .config import ACOConfig
@dataclass
class PromptLayout:
    """Result of partitioning prompt content around a prefix-cache boundary."""

    prefix: str  # Stable, cacheable content (system rules, tools, persona)
    suffix: str  # Dynamic content that changes per turn (user msg, docs, trace)
    # Estimated token counts for each partition (rough 4-chars-per-token heuristic).
    prefix_tokens: int
    suffix_tokens: int
    # Token offset where the cacheable prefix ends (== prefix_tokens).
    cache_boundary_token: int
    # Input cost in dollars when nothing is cached.
    estimated_cold_cost: float
    # Input cost in dollars when the prefix is served from cache.
    estimated_warm_cost: float
    # Dollar amount saved per warm run (cold cost minus warm cost).
    cache_discount: float
class CacheAwarePromptLayout:
"""Lays out prompts to maximize prefix cache reuse."""
# Content types that should stay in prefix
PREFIX_CONTENT_TYPES = [
"system_rules",
"tool_descriptions",
"user_preferences",
"static_examples",
"persona_definition",
]
# Content types that should be in suffix
SUFFIX_CONTENT_TYPES = [
"user_message",
"retrieved_docs",
"recent_trace",
"artifacts",
"timestamp",
"session_id",
"dynamic_examples",
"conversation_history",
]
def __init__(self, config: Optional[ACOConfig] = None):
self.config = config or ACOConfig()
self.cache_stats = {
"cold_runs": 0,
"warm_runs": 0,
"prefix_tokens_avg": 0,
"cache_hit_rate": 0.0,
"staleness_failures": 0,
}
def layout(
self,
content_pieces: Dict[str, str],
cost_per_1k_input: float = 0.01,
cache_discount_rate: float = 0.5,
) -> PromptLayout:
"""Partition content into prefix (cacheable) and suffix (dynamic)."""
prefix_pieces = []
suffix_pieces = []
for key, text in content_pieces.items():
if self._is_prefix_content(key):
prefix_pieces.append(text)
else:
suffix_pieces.append(text)
# Sort prefix: most stable first
prefix = "\n\n".join(prefix_pieces)
suffix = "\n\n".join(suffix_pieces)
# Token estimation (rough: 1 token ~ 4 chars for English)
prefix_tokens = len(prefix) // 4
suffix_tokens = len(suffix) // 4
# Costs
estimated_cold_cost = ((prefix_tokens + suffix_tokens) / 1000) * cost_per_1k_input
estimated_warm_cost = ((suffix_tokens + prefix_tokens * (1 - cache_discount_rate)) / 1000) * cost_per_1k_input
cache_discount = ((prefix_tokens * cache_discount_rate) / 1000) * cost_per_1k_input
return PromptLayout(
prefix=prefix,
suffix=suffix,
prefix_tokens=prefix_tokens,
suffix_tokens=suffix_tokens,
cache_boundary_token=prefix_tokens,
estimated_cold_cost=estimated_cold_cost,
estimated_warm_cost=estimated_warm_cost,
cache_discount=cache_discount,
)
def _is_prefix_content(self, key: str) -> bool:
"""Determine if a content key belongs in the prefix."""
# Direct matches
if key in self.PREFIX_CONTENT_TYPES:
return True
if key in self.SUFFIX_CONTENT_TYPES:
return False
# Pattern matching
if any(kw in key.lower() for kw in ["system", "static", "rule", "persona", "schema", "format"]):
return True
if any(kw in key.lower() for kw in ["user_", "dynamic", "current", "live", "now", "timestamp"]):
return False
# Default: prefix if name suggests stability
return True
def optimize_for_provider(
self,
layout: PromptLayout,
provider: str,
) -> PromptLayout:
"""Provider-specific cache layout optimizations."""
provider = provider.lower()
if "anthropic" in provider:
# Claude has system prompts that are automatically cached
# Keep system content separate
return layout
elif "openai" in provider:
# OpenAI has prefix caching on system + first user message
# Ensure system content is at the very top
return layout
elif "gemini" in provider:
# Gemini has context caching for repeated contexts
return layout
elif "deepseek" in provider:
# DeepSeek has cache hit discounts
return layout
return layout
def measure_hit_rate(self, prefix_tokens: int, cache_hit: bool) -> None:
"""Update cache statistics."""
if cache_hit:
self.cache_stats["warm_runs"] += 1
else:
self.cache_stats["cold_runs"] += 1
total = self.cache_stats["warm_runs"] + self.cache_stats["cold_runs"]
self.cache_stats["cache_hit_rate"] = self.cache_stats["warm_runs"] / total if total > 0 else 0.0
# Running average of prefix tokens
n = total
self.cache_stats["prefix_tokens_avg"] = (
(self.cache_stats["prefix_tokens_avg"] * (n - 1) + prefix_tokens) / n
)
def report(self) -> Dict[str, Any]:
"""Generate cache performance report."""
total = self.cache_stats["warm_runs"] + self.cache_stats["cold_runs"]
return {
"total_runs": total,
"warm_runs": self.cache_stats["warm_runs"],
"cold_runs": self.cache_stats["cold_runs"],
"cache_hit_rate": self.cache_stats["cache_hit_rate"],
"avg_prefix_tokens": self.cache_stats["prefix_tokens_avg"],
"staleness_failures": self.cache_stats["staleness_failures"],
"estimated_cost_saved": self.cache_stats["warm_runs"] * self.cache_stats.get("avg_cache_discount", 0.0),
}