File size: 6,251 Bytes
07f0bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""Cache-Aware Prompt Layout - Module 5.

Optimizes prompt/context structure for prefix-cache reuse.

Strategy:
- Keep stable rules in the prefix (system rules, tool descriptions, user preferences)
- Keep tool descriptions stable
- Move dynamic content to the suffix (user message, retrieved docs, recent trace, artifacts)
- Avoid injecting timestamps/random metadata above cache boundary
- Preserve sticky provider/session routing where useful

Metrics:
- cache hit rate
- warm-cache cost
- cold-cache cost
- latency
- context staleness failures
"""

from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass

from .config import ACOConfig


@dataclass
class PromptLayout:
    """Result of splitting prompt content into a cacheable prefix and a per-turn suffix.

    Token counts are rough estimates (characters // 4) and costs are in the same
    currency unit as the per-1k-token input price used to build the layout.
    """

    prefix: str  # Stable, cacheable content
    suffix: str  # Dynamic content per turn
    prefix_tokens: int  # Estimated token count of `prefix`
    suffix_tokens: int  # Estimated token count of `suffix`
    cache_boundary_token: int  # Token offset where the cacheable prefix ends (== prefix_tokens when built by layout())
    estimated_cold_cost: float  # Estimated input cost with nothing cached (prefix + suffix at full price)
    estimated_warm_cost: float  # Estimated input cost with the prefix served from cache at a discount
    cache_discount: float  # estimated_cold_cost - estimated_warm_cost, i.e. cost saved per warm run


class CacheAwarePromptLayout:
    """Lays out prompts to maximize prefix cache reuse.

    Content is supplied as a mapping of content-type key to text. Keys judged
    stable go into the prefix (the part that can be reused across turns);
    everything else goes into the per-turn suffix. The prefix is assembled in a
    deterministic order so identical stable content always produces a
    byte-identical prefix, whatever iteration order the caller's mapping has.
    """

    # Content types that should stay in prefix. Declaration order is also the
    # order in which these pieces are placed in the prefix.
    PREFIX_CONTENT_TYPES = [
        "system_rules",
        "tool_descriptions",
        "user_preferences",
        "static_examples",
        "persona_definition",
    ]

    # Content types that should be in suffix
    SUFFIX_CONTENT_TYPES = [
        "user_message",
        "retrieved_docs",
        "recent_trace",
        "artifacts",
        "timestamp",
        "session_id",
        "dynamic_examples",
        "conversation_history",
    ]

    # Word prefixes marking an unlisted key as per-turn (suffix) content. They
    # are matched against the individual words of the key (split on
    # non-alphanumerics), so e.g. "knowledge" does not trigger "now" and
    # "information" is not mistaken for "format".
    _DYNAMIC_KEYWORDS = ("user", "dynamic", "current", "live", "now", "timestamp")

    def __init__(self, config: Optional[ACOConfig] = None):
        self.config = config or ACOConfig()
        self.cache_stats = {
            "cold_runs": 0,
            "warm_runs": 0,
            "prefix_tokens_avg": 0,
            "cache_hit_rate": 0.0,
            "staleness_failures": 0,
            # Sum of the cache_discount values reported for warm runs.
            "cost_saved": 0.0,
        }

    def layout(
        self,
        content_pieces: Dict[str, str],
        cost_per_1k_input: float = 0.01,
        cache_discount_rate: float = 0.5,
    ) -> PromptLayout:
        """Partition content into prefix (cacheable) and suffix (dynamic).

        Args:
            content_pieces: Mapping of content key (e.g. "system_rules") to text.
            cost_per_1k_input: Price per 1,000 input tokens.
            cache_discount_rate: Fraction of the prefix price saved when the
                prefix is served from cache; must be within [0, 1].

        Returns:
            A PromptLayout. Prefix pieces are ordered by PREFIX_CONTENT_TYPES
            declaration order, then by key name; suffix pieces keep the
            caller's order. Token counts are estimated as characters // 4.

        Raises:
            ValueError: if cache_discount_rate is outside [0, 1] (or NaN).
        """
        if not 0.0 <= cache_discount_rate <= 1.0:
            raise ValueError(
                f"cache_discount_rate must be between 0 and 1, got {cache_discount_rate!r}"
            )

        prefix_items: List[Tuple[str, str]] = []
        suffix_pieces: List[str] = []
        for key, text in content_pieces.items():
            if self._is_prefix_content(key):
                prefix_items.append((key, text))
            else:
                suffix_pieces.append(text)

        # Order the prefix deterministically: declared stable types first (in
        # PREFIX_CONTENT_TYPES order), then any other prefix keys by name. If the
        # prefix followed the caller's insertion order, supplying the same
        # content in a different order would change the prefix bytes and miss
        # the cache.
        rank = {name: i for i, name in enumerate(self.PREFIX_CONTENT_TYPES)}
        prefix_items.sort(key=lambda item: (rank.get(item[0].lower(), len(rank)), item[0]))

        prefix = "\n\n".join(text for _, text in prefix_items)
        suffix = "\n\n".join(suffix_pieces)

        # Token estimation (rough: 1 token ~ 4 chars for English)
        prefix_tokens = len(prefix) // 4
        suffix_tokens = len(suffix) // 4

        # Cold: every token at full price. Warm: prefix tokens billed at
        # (1 - cache_discount_rate). cache_discount is the difference.
        estimated_cold_cost = ((prefix_tokens + suffix_tokens) / 1000) * cost_per_1k_input
        estimated_warm_cost = (
            (suffix_tokens + prefix_tokens * (1 - cache_discount_rate)) / 1000
        ) * cost_per_1k_input
        cache_discount = ((prefix_tokens * cache_discount_rate) / 1000) * cost_per_1k_input

        return PromptLayout(
            prefix=prefix,
            suffix=suffix,
            prefix_tokens=prefix_tokens,
            suffix_tokens=suffix_tokens,
            cache_boundary_token=prefix_tokens,
            estimated_cold_cost=estimated_cold_cost,
            estimated_warm_cost=estimated_warm_cost,
            cache_discount=cache_discount,
        )

    def _is_prefix_content(self, key: str) -> bool:
        """Return True if the content stored under *key* belongs in the prefix.

        Resolution order:
          1. Case-insensitive exact match against PREFIX_CONTENT_TYPES or
             SUFFIX_CONTENT_TYPES.
          2. Any word of the key starting with a dynamic marker
             (_DYNAMIC_KEYWORDS) -> suffix.
          3. Everything else -> prefix. This covers keys naming stable material
             such as "system", "rule", "persona", "schema", "format", "static".

        Dynamic markers win over stable ones (e.g. "system_timestamp" -> suffix).
        Per-turn content placed above the cache boundary invalidates the cached
        prefix on every turn.
        """
        normalized = key.lower()
        if normalized in self.PREFIX_CONTENT_TYPES:
            return True
        if normalized in self.SUFFIX_CONTENT_TYPES:
            return False

        words = "".join(ch if ch.isalnum() else " " for ch in normalized).split()
        if any(word.startswith(self._DYNAMIC_KEYWORDS) for word in words):
            return False

        # Stable-looking and unrecognised keys both default to the prefix.
        return True

    def optimize_for_provider(
        self,
        layout: PromptLayout,
        provider: str,
    ) -> PromptLayout:
        """Provider-specific cache layout optimizations.

        Currently a pass-through: the layout is returned unchanged for every
        provider, and ``provider`` is not inspected.
        """
        # NOTE(review): providers differ in how prefix caching is triggered
        # (e.g. explicit cache breakpoints vs. automatic prefix matching, and
        # minimum cacheable lengths). Confirm each provider's behavior against
        # its documentation before adding per-provider adjustments here.
        return layout

    def measure_hit_rate(
        self,
        prefix_tokens: int,
        cache_hit: bool,
        cache_discount: float = 0.0,
    ) -> None:
        """Record one run in the cache statistics.

        Args:
            prefix_tokens: Prefix size of this run, folded into a running mean.
            cache_hit: True for a warm (cached) run, False for a cold run.
            cache_discount: Cost saved by this run if it was warm (e.g.
                ``PromptLayout.cache_discount``). It is added to ``cost_saved``
                only when ``cache_hit`` is True.
        """
        stats = self.cache_stats
        if cache_hit:
            stats["warm_runs"] += 1
            stats["cost_saved"] = stats.get("cost_saved", 0.0) + cache_discount
        else:
            stats["cold_runs"] += 1

        # total >= 1 here because one counter was just incremented.
        total = stats["warm_runs"] + stats["cold_runs"]
        stats["cache_hit_rate"] = stats["warm_runs"] / total

        # Running average of prefix tokens
        stats["prefix_tokens_avg"] = (
            (stats["prefix_tokens_avg"] * (total - 1) + prefix_tokens) / total
        )

    def record_staleness_failure(self) -> None:
        """Count one context-staleness failure (reported as ``staleness_failures``)."""
        self.cache_stats["staleness_failures"] += 1

    def report(self) -> Dict[str, Any]:
        """Generate cache performance report.

        ``estimated_cost_saved`` is the sum of the ``cache_discount`` values
        passed to measure_hit_rate() for warm runs.
        """
        stats = self.cache_stats
        total = stats["warm_runs"] + stats["cold_runs"]
        return {
            "total_runs": total,
            "warm_runs": stats["warm_runs"],
            "cold_runs": stats["cold_runs"],
            "cache_hit_rate": stats["cache_hit_rate"],
            "avg_prefix_tokens": stats["prefix_tokens_avg"],
            "staleness_failures": stats["staleness_failures"],
            "estimated_cost_saved": stats.get("cost_saved", 0.0),
        }