narcolepticchicken committed on
Commit
07f0bb4
·
verified ·
1 Parent(s): 581261a

Upload aco/cache_layout.py

Browse files
Files changed (1) hide show
  1. aco/cache_layout.py +186 -0
aco/cache_layout.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cache-Aware Prompt Layout - Module 5.
2
+
3
+ Optimizes prompt/context structure for prefix-cache reuse.
4
+
5
+ Strategy:
6
+ - Keep stable rules in the prefix (system rules, tool descriptions, user preferences)
7
+ - Keep tool descriptions stable
8
+ - Move dynamic content to the suffix (user message, retrieved docs, recent trace, artifacts)
9
+ - Avoid injecting timestamps/random metadata above cache boundary
10
+ - Preserve sticky provider/session routing where useful
11
+
12
+ Metrics:
13
+ - cache hit rate
14
+ - warm-cache cost
15
+ - cold-cache cost
16
+ - latency
17
+ - context staleness failures
18
+ """
19
+
20
+ from typing import Dict, List, Tuple, Optional, Any
21
+ from dataclasses import dataclass
22
+
23
+ from .config import ACOConfig
24
+
25
+
26
@dataclass
class PromptLayout:
    """A prompt partitioned into a cacheable prefix and a dynamic suffix.

    Produced by ``CacheAwarePromptLayout.layout``; carries the joined text
    of each partition plus rough token counts and cost estimates.
    """

    prefix: str  # Stable, cacheable content
    suffix: str  # Dynamic content per turn
    prefix_tokens: int  # Rough token count of the prefix
    suffix_tokens: int  # Rough token count of the suffix
    cache_boundary_token: int  # Token offset where the cacheable prefix ends
    estimated_cold_cost: float  # Cost when nothing is cached
    estimated_warm_cost: float  # Cost when the prefix is served from cache
    cache_discount: float  # Estimated cost saved on a warm run
36
+
37
+
38
class CacheAwarePromptLayout:
    """Lays out prompts to maximize prefix cache reuse.

    Partitions prompt content into a stable, cacheable prefix and a
    dynamic per-turn suffix, estimates cold/warm input costs, and tracks
    cache-hit statistics across runs.
    """

    # Content types that should stay in the stable (cacheable) prefix.
    PREFIX_CONTENT_TYPES = [
        "system_rules",
        "tool_descriptions",
        "user_preferences",
        "static_examples",
        "persona_definition",
    ]

    # Content types that belong in the dynamic (per-turn) suffix.
    SUFFIX_CONTENT_TYPES = [
        "user_message",
        "retrieved_docs",
        "recent_trace",
        "artifacts",
        "timestamp",
        "session_id",
        "dynamic_examples",
        "conversation_history",
    ]

    def __init__(self, config: Optional[ACOConfig] = None):
        """Initialize with an optional config; defaults to ``ACOConfig()``."""
        self.config = config or ACOConfig()
        # Aggregate statistics updated by measure_hit_rate() and read by report().
        self.cache_stats = {
            "cold_runs": 0,
            "warm_runs": 0,
            "prefix_tokens_avg": 0,
            "cache_hit_rate": 0.0,
            "staleness_failures": 0,
            # Running average of the per-run cache discount (cost units).
            # Previously this key was never written, so report()'s
            # estimated_cost_saved was always 0.
            "avg_cache_discount": 0.0,
        }

    def layout(
        self,
        content_pieces: Dict[str, str],
        cost_per_1k_input: float = 0.01,
        cache_discount_rate: float = 0.5,
    ) -> PromptLayout:
        """Partition content into prefix (cacheable) and suffix (dynamic).

        Args:
            content_pieces: Mapping from content-type key (e.g.
                ``"system_rules"``, ``"user_message"``) to its text.
            cost_per_1k_input: Input-token cost per 1000 tokens.
            cache_discount_rate: Fraction of the prefix cost saved on a
                warm (cache-hit) run.

        Returns:
            A ``PromptLayout`` with the joined prefix/suffix, rough token
            counts, and cold/warm cost estimates.
        """
        prefix_pieces: List[str] = []
        suffix_pieces: List[str] = []

        # Dict insertion order is preserved within each partition.
        for key, text in content_pieces.items():
            if self._is_prefix_content(key):
                prefix_pieces.append(text)
            else:
                suffix_pieces.append(text)

        prefix = "\n\n".join(prefix_pieces)
        suffix = "\n\n".join(suffix_pieces)

        # Rough token estimate: ~1 token per 4 characters of English text.
        prefix_tokens = len(prefix) // 4
        suffix_tokens = len(suffix) // 4

        # Cold run pays for everything; a warm run pays the full suffix
        # plus the discounted prefix.
        estimated_cold_cost = ((prefix_tokens + suffix_tokens) / 1000) * cost_per_1k_input
        estimated_warm_cost = (
            (suffix_tokens + prefix_tokens * (1 - cache_discount_rate)) / 1000
        ) * cost_per_1k_input
        cache_discount = ((prefix_tokens * cache_discount_rate) / 1000) * cost_per_1k_input

        return PromptLayout(
            prefix=prefix,
            suffix=suffix,
            prefix_tokens=prefix_tokens,
            suffix_tokens=suffix_tokens,
            cache_boundary_token=prefix_tokens,
            estimated_cold_cost=estimated_cold_cost,
            estimated_warm_cost=estimated_warm_cost,
            cache_discount=cache_discount,
        )

    def _is_prefix_content(self, key: str) -> bool:
        """Determine if a content key belongs in the cacheable prefix."""
        # Direct matches against the known content-type lists win.
        if key in self.PREFIX_CONTENT_TYPES:
            return True
        if key in self.SUFFIX_CONTENT_TYPES:
            return False

        # Fall back to keyword pattern matching on the (lowercased) key.
        key_lower = key.lower()
        if any(kw in key_lower for kw in ["system", "static", "rule", "persona", "schema", "format"]):
            return True
        if any(kw in key_lower for kw in ["user_", "dynamic", "current", "live", "now", "timestamp"]):
            return False

        # Default: assume stability and keep unknown content in the prefix.
        return True

    def optimize_for_provider(
        self,
        layout: PromptLayout,
        provider: str,
    ) -> PromptLayout:
        """Apply provider-specific cache layout optimizations.

        Currently a documented pass-through: every branch returns the
        layout unchanged. The branches record per-provider caching notes
        for future provider-specific adjustments.
        """
        provider = provider.lower()

        if "anthropic" in provider:
            # Claude caches marked prompt prefixes; keep system content separate.
            return layout

        elif "openai" in provider:
            # OpenAI prefix-caches from the top of the prompt; keep system
            # content at the very top.
            return layout

        elif "gemini" in provider:
            # Gemini offers explicit context caching for repeated contexts.
            return layout

        elif "deepseek" in provider:
            # DeepSeek prices cache hits at a discount.
            return layout

        return layout

    def measure_hit_rate(
        self,
        prefix_tokens: int,
        cache_hit: bool,
        cache_discount: float = 0.0,
    ) -> None:
        """Update cache statistics after a run.

        Args:
            prefix_tokens: Prefix size (tokens) used for this run.
            cache_hit: Whether the provider served the prefix from cache.
            cache_discount: Estimated cost saved on this run (e.g. the
                ``PromptLayout.cache_discount`` value). Optional and
                defaults to 0.0 for backward compatibility.
        """
        if cache_hit:
            self.cache_stats["warm_runs"] += 1
        else:
            self.cache_stats["cold_runs"] += 1

        total = self.cache_stats["warm_runs"] + self.cache_stats["cold_runs"]
        self.cache_stats["cache_hit_rate"] = (
            self.cache_stats["warm_runs"] / total if total > 0 else 0.0
        )

        # Running averages over all runs (total >= 1 at this point).
        self.cache_stats["prefix_tokens_avg"] = (
            (self.cache_stats["prefix_tokens_avg"] * (total - 1) + prefix_tokens) / total
        )
        prev_discount_avg = self.cache_stats.get("avg_cache_discount", 0.0)
        self.cache_stats["avg_cache_discount"] = (
            (prev_discount_avg * (total - 1) + cache_discount) / total
        )

    def report(self) -> Dict[str, Any]:
        """Generate a cache performance report from accumulated stats."""
        total = self.cache_stats["warm_runs"] + self.cache_stats["cold_runs"]
        return {
            "total_runs": total,
            "warm_runs": self.cache_stats["warm_runs"],
            "cold_runs": self.cache_stats["cold_runs"],
            "cache_hit_rate": self.cache_stats["cache_hit_rate"],
            "avg_prefix_tokens": self.cache_stats["prefix_tokens_avg"],
            "staleness_failures": self.cache_stats["staleness_failures"],
            # Rough total saving: warm runs times the average per-run discount.
            "estimated_cost_saved": self.cache_stats["warm_runs"]
            * self.cache_stats.get("avg_cache_discount", 0.0),
        }