narcolepticchicken committed on
Commit
eaafe86
·
verified ·
1 Parent(s): c8ece28

Upload aco/cache_layout.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/cache_layout.py +68 -172
aco/cache_layout.py CHANGED
@@ -1,186 +1,82 @@
1
- """Cache-Aware Prompt Layout - Module 5.
2
-
3
- Optimizes prompt/context structure for prefix-cache reuse.
4
-
5
- Strategy:
6
- - Keep stable rules in the prefix (system rules, tool descriptions, user preferences)
7
- - Keep tool descriptions stable
8
- - Move dynamic content to the suffix (user message, retrieved docs, recent trace, artifacts)
9
- - Avoid injecting timestamps/random metadata above cache boundary
10
- - Preserve sticky provider/session routing where useful
11
-
12
- Metrics:
13
- - cache hit rate
14
- - warm-cache cost
15
- - cold-cache cost
16
- - latency
17
- - context staleness failures
18
- """
19
-
20
- from typing import Dict, List, Tuple, Optional, Any
21
  from dataclasses import dataclass
22
 
23
- from .config import ACOConfig
24
-
25
-
26
  @dataclass
27
  class PromptLayout:
28
- prefix: str # Stable, cacheable content
29
- suffix: str # Dynamic content per turn
30
  prefix_tokens: int
31
  suffix_tokens: int
32
- cache_boundary_token: int
33
- estimated_cold_cost: float
34
- estimated_warm_cost: float
35
- cache_discount: float
36
-
37
-
38
- class CacheAwarePromptLayout:
39
- """Lays out prompts to maximize prefix cache reuse."""
40
-
41
- # Content types that should stay in prefix
42
- PREFIX_CONTENT_TYPES = [
43
- "system_rules",
44
- "tool_descriptions",
45
- "user_preferences",
46
- "static_examples",
47
- "persona_definition",
48
- ]
49
-
50
- # Content types that should be in suffix
51
- SUFFIX_CONTENT_TYPES = [
52
- "user_message",
53
- "retrieved_docs",
54
- "recent_trace",
55
- "artifacts",
56
- "timestamp",
57
- "session_id",
58
- "dynamic_examples",
59
- "conversation_history",
60
- ]
61
-
62
- def __init__(self, config: Optional[ACOConfig] = None):
63
- self.config = config or ACOConfig()
64
- self.cache_stats = {
65
- "cold_runs": 0,
66
- "warm_runs": 0,
67
- "prefix_tokens_avg": 0,
68
- "cache_hit_rate": 0.0,
69
- "staleness_failures": 0,
70
- }
71
-
72
- def layout(
73
- self,
74
- content_pieces: Dict[str, str],
75
- cost_per_1k_input: float = 0.01,
76
- cache_discount_rate: float = 0.5,
77
- ) -> PromptLayout:
78
- """Partition content into prefix (cacheable) and suffix (dynamic)."""
79
-
80
- prefix_pieces = []
81
- suffix_pieces = []
82
-
83
- for key, text in content_pieces.items():
84
- if self._is_prefix_content(key):
85
- prefix_pieces.append(text)
86
  else:
87
- suffix_pieces.append(text)
88
-
89
- # Sort prefix: most stable first
90
- prefix = "\n\n".join(prefix_pieces)
91
- suffix = "\n\n".join(suffix_pieces)
92
-
93
- # Token estimation (rough: 1 token ~ 4 chars for English)
94
- prefix_tokens = len(prefix) // 4
95
- suffix_tokens = len(suffix) // 4
96
-
97
- # Costs
98
- estimated_cold_cost = ((prefix_tokens + suffix_tokens) / 1000) * cost_per_1k_input
99
- estimated_warm_cost = ((suffix_tokens + prefix_tokens * (1 - cache_discount_rate)) / 1000) * cost_per_1k_input
100
- cache_discount = ((prefix_tokens * cache_discount_rate) / 1000) * cost_per_1k_input
101
-
102
  return PromptLayout(
103
- prefix=prefix,
104
- suffix=suffix,
105
  prefix_tokens=prefix_tokens,
106
  suffix_tokens=suffix_tokens,
107
- cache_boundary_token=prefix_tokens,
108
- estimated_cold_cost=estimated_cold_cost,
109
- estimated_warm_cost=estimated_warm_cost,
110
- cache_discount=cache_discount,
111
- )
112
-
113
- def _is_prefix_content(self, key: str) -> bool:
114
- """Determine if a content key belongs in the prefix."""
115
- # Direct matches
116
- if key in self.PREFIX_CONTENT_TYPES:
117
- return True
118
- if key in self.SUFFIX_CONTENT_TYPES:
119
- return False
120
-
121
- # Pattern matching
122
- if any(kw in key.lower() for kw in ["system", "static", "rule", "persona", "schema", "format"]):
123
- return True
124
- if any(kw in key.lower() for kw in ["user_", "dynamic", "current", "live", "now", "timestamp"]):
125
- return False
126
-
127
- # Default: prefix if name suggests stability
128
- return True
129
-
130
- def optimize_for_provider(
131
- self,
132
- layout: PromptLayout,
133
- provider: str,
134
- ) -> PromptLayout:
135
- """Provider-specific cache layout optimizations."""
136
-
137
- provider = provider.lower()
138
-
139
- if "anthropic" in provider:
140
- # Claude has system prompts that are automatically cached
141
- # Keep system content separate
142
- return layout
143
-
144
- elif "openai" in provider:
145
- # OpenAI has prefix caching on system + first user message
146
- # Ensure system content is at the very top
147
- return layout
148
-
149
- elif "gemini" in provider:
150
- # Gemini has context caching for repeated contexts
151
- return layout
152
-
153
- elif "deepseek" in provider:
154
- # DeepSeek has cache hit discounts
155
- return layout
156
-
157
- return layout
158
-
159
- def measure_hit_rate(self, prefix_tokens: int, cache_hit: bool) -> None:
160
- """Update cache statistics."""
161
- if cache_hit:
162
- self.cache_stats["warm_runs"] += 1
163
- else:
164
- self.cache_stats["cold_runs"] += 1
165
-
166
- total = self.cache_stats["warm_runs"] + self.cache_stats["cold_runs"]
167
- self.cache_stats["cache_hit_rate"] = self.cache_stats["warm_runs"] / total if total > 0 else 0.0
168
-
169
- # Running average of prefix tokens
170
- n = total
171
- self.cache_stats["prefix_tokens_avg"] = (
172
- (self.cache_stats["prefix_tokens_avg"] * (n - 1) + prefix_tokens) / n
173
  )
174
 
175
- def report(self) -> Dict[str, Any]:
176
- """Generate cache performance report."""
177
- total = self.cache_stats["warm_runs"] + self.cache_stats["cold_runs"]
178
  return {
179
- "total_runs": total,
180
- "warm_runs": self.cache_stats["warm_runs"],
181
- "cold_runs": self.cache_stats["cold_runs"],
182
- "cache_hit_rate": self.cache_stats["cache_hit_rate"],
183
- "avg_prefix_tokens": self.cache_stats["prefix_tokens_avg"],
184
- "staleness_failures": self.cache_stats["staleness_failures"],
185
- "estimated_cost_saved": self.cache_stats["warm_runs"] * self.cache_stats.get("avg_cache_discount", 0.0),
186
  }
 
1
+ """Cache-Aware Prompt Layout: Optimize prompt structure for prefix-cache reuse."""
2
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
4
 
 
 
 
5
@dataclass
class PromptLayout:
    """Result of partitioning prompt sources for prefix-cache reuse.

    The prefix holds stable, cacheable content; the suffix holds the
    dynamic, per-turn content appended after the cache boundary.
    """

    prefix_content: str  # stable, cacheable text placed before the boundary
    suffix_content: str  # dynamic per-turn text placed after the boundary
    prefix_tokens: int  # estimated token count of prefix_content
    suffix_tokens: int  # estimated token count of suffix_content
    cache_boundary: int  # token position of cache boundary (set to prefix_tokens by the builder)
    stable_sources: List[str]  # names of sources routed to the prefix
    dynamic_sources: List[str]  # names of sources routed to the suffix
14
+
15
# Default classification of source names for cache layout.
# NOTE(review): neither set is referenced by CacheAwareLayout.layout(), which
# routes sources via the context_budget object instead — confirm whether these
# are consumed elsewhere in the package or are dead code.
CACHE_STABLE_SOURCES = {"system_rules", "tool_descriptions", "user_preferences"}
CACHE_DYNAMIC_SOURCES = {"recent_messages", "task_plan", "retrieved_docs", "artifacts"}
17
+
18
+ class CacheAwareLayout:
19
+ def __init__(self, session_id: str = None):
20
+ self.session_id = session_id
21
+ self._prev_prefix_hash = None
22
+ self.cache_hits = 0
23
+ self.cache_misses = 0
24
+ self.total_prefix_tokens = 0
25
+ self.total_suffix_tokens = 0
26
+
27
+ def layout(self, sources: Dict[str, str], context_budget) -> PromptLayout:
28
+ prefix_parts = []
29
+ suffix_parts = []
30
+ stable = []
31
+ dynamic = []
32
+ prefix_tokens = 0
33
+ suffix_tokens = 0
34
+ for source_name, content in sources.items():
35
+ token_est = len(content) // 4 # rough estimate
36
+ if source_name in context_budget.cache_prefix:
37
+ prefix_parts.append(f"# {source_name}\n{content}")
38
+ prefix_tokens += token_est
39
+ stable.append(source_name)
40
+ elif source_name in context_budget.dynamic_suffix:
41
+ suffix_parts.append(f"# {source_name}\n{content}")
42
+ suffix_tokens += token_est
43
+ dynamic.append(source_name)
44
+ elif source_name in context_budget.keep_exact:
45
+ prefix_parts.append(f"# {source_name}\n{content}")
46
+ prefix_tokens += token_est
47
+ stable.append(source_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  else:
49
+ suffix_parts.append(f"# {source_name}\n{content}")
50
+ suffix_tokens += token_est
51
+ dynamic.append(source_name)
52
+ prefix = "\n\n".join(prefix_parts)
53
+ suffix = "\n\n".join(suffix_parts)
54
+ # Check cache hit
55
+ import hashlib
56
+ prefix_hash = hashlib.md5(prefix.encode()).hexdigest()
57
+ if self._prev_prefix_hash == prefix_hash:
58
+ self.cache_hits += 1
59
+ else:
60
+ self.cache_misses += 1
61
+ self._prev_prefix_hash = prefix_hash
62
+ self.total_prefix_tokens += prefix_tokens
63
+ self.total_suffix_tokens += suffix_tokens
64
  return PromptLayout(
65
+ prefix_content=prefix,
66
+ suffix_content=suffix,
67
  prefix_tokens=prefix_tokens,
68
  suffix_tokens=suffix_tokens,
69
+ cache_boundary=prefix_tokens,
70
+ stable_sources=stable,
71
+ dynamic_sources=dynamic,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  )
73
 
74
+ def stats(self) -> Dict:
75
+ total = self.cache_hits + self.cache_misses
 
76
  return {
77
+ "cache_hit_rate": self.cache_hits / max(total, 1),
78
+ "total_cache_hits": self.cache_hits,
79
+ "total_cache_misses": self.cache_misses,
80
+ "avg_prefix_tokens": self.total_prefix_tokens / max(total, 1),
81
+ "avg_suffix_tokens": self.total_suffix_tokens / max(total, 1),
 
 
82
  }