File size: 3,206 Bytes
eaafe86
 
07f0bb4
 
 
 
eaafe86
 
07f0bb4
 
eaafe86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07f0bb4
eaafe86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07f0bb4
eaafe86
 
07f0bb4
 
eaafe86
 
 
07f0bb4
 
eaafe86
 
07f0bb4
eaafe86
 
 
 
 
07f0bb4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Cache-Aware Prompt Layout: Optimize prompt structure for prefix-cache reuse."""
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@dataclass
class PromptLayout:
    """Prompt split into a prefix part and a suffix part.

    Instances are built by ``CacheAwareLayout.layout``; the field meanings
    below follow how that method fills them in.

    Attributes:
        prefix_content: Joined text of every section placed in the prefix.
        suffix_content: Joined text of every section placed in the suffix.
        prefix_tokens: Estimated token count of the prefix sections.
        suffix_tokens: Estimated token count of the suffix sections.
        cache_boundary: Token position of the cache boundary.
        stable_sources: Names of the sources that went into the prefix.
        dynamic_sources: Names of the sources that went into the suffix.
    """

    # Rendered text for each side of the boundary.
    prefix_content: str
    suffix_content: str

    # Estimated token counts for each side.
    prefix_tokens: int
    suffix_tokens: int

    # Token position of the cache boundary.
    cache_boundary: int

    # Source names assigned to each side.
    stable_sources: List[str]
    dynamic_sources: List[str]

# Named groups of prompt sources. Nothing in the visible part of this module
# reads these sets; presumably callers use them as the default stable/dynamic
# classification -- confirm against the call sites before relying on that.
CACHE_STABLE_SOURCES = {
    "system_rules",
    "tool_descriptions",
    "user_preferences",
}
CACHE_DYNAMIC_SOURCES = {
    "recent_messages",
    "task_plan",
    "retrieved_docs",
    "artifacts",
}

class CacheAwareLayout:
    """Split prompt sources into a cacheable prefix and a dynamic suffix.

    Each call to :meth:`layout` partitions the supplied sources, joins each
    side into a single string, and records whether the resulting prefix is
    identical to the prefix produced by the previous call on this instance.
    The running counters are summarised by :meth:`stats`.

    Attributes:
        session_id: Caller-supplied identifier; stored but not otherwise
            used by this class.
        cache_hits: Number of ``layout`` calls whose prefix matched the
            prefix of the immediately preceding call.
        cache_misses: Number of ``layout`` calls whose prefix did not match
            (the first call on an instance always counts as a miss).
        total_prefix_tokens: Sum of estimated prefix tokens over all calls.
        total_suffix_tokens: Sum of estimated suffix tokens over all calls.
    """

    def __init__(self, session_id: Optional[str] = None):
        """Create a layout tracker with all counters at zero.

        Args:
            session_id: Optional identifier kept on the instance.
        """
        self.session_id = session_id
        # Digest of the most recent prefix; None until layout() has run once.
        self._prev_prefix_hash: Optional[str] = None
        self.cache_hits = 0
        self.cache_misses = 0
        self.total_prefix_tokens = 0
        self.total_suffix_tokens = 0

    @staticmethod
    def _is_stable(source_name: str, context_budget) -> bool:
        """Return True when *source_name* belongs in the cacheable prefix.

        Precedence: ``cache_prefix`` membership wins, then ``dynamic_suffix``
        membership forces the suffix, then ``keep_exact`` membership selects
        the prefix. Names in none of the three go to the suffix.
        """
        if source_name in context_budget.cache_prefix:
            return True
        if source_name in context_budget.dynamic_suffix:
            return False
        return source_name in context_budget.keep_exact

    def layout(self, sources: Dict[str, str], context_budget) -> PromptLayout:
        """Partition *sources* into prefix and suffix and update statistics.

        Sections are emitted in the iteration order of *sources*, so the
        prefix (and therefore its digest) only repeats between calls when the
        stable sources arrive in the same order with the same content.

        Args:
            sources: Mapping of source name to its text content.
            context_budget: Object exposing ``cache_prefix``,
                ``dynamic_suffix`` and ``keep_exact`` containers of source
                names (anything supporting ``in``).

        Returns:
            A ``PromptLayout`` with the joined prefix/suffix text, the
            estimated token counts, and the source names on each side.
            ``cache_boundary`` equals the estimated prefix token count.
        """
        prefix_parts: List[str] = []
        suffix_parts: List[str] = []
        stable: List[str] = []
        dynamic: List[str] = []
        prefix_tokens = 0
        suffix_tokens = 0

        for source_name, content in sources.items():
            token_est = len(content) // 4  # rough estimate: ~4 chars per token
            section = f"# {source_name}\n{content}"
            if self._is_stable(source_name, context_budget):
                prefix_parts.append(section)
                prefix_tokens += token_est
                stable.append(source_name)
            else:
                suffix_parts.append(section)
                suffix_tokens += token_est
                dynamic.append(source_name)

        prefix = "\n\n".join(prefix_parts)
        suffix = "\n\n".join(suffix_parts)

        # The digest is only compared for equality with the previous call.
        # "surrogatepass" keeps content containing lone surrogates from
        # raising UnicodeEncodeError, and sha256 stays available on Python
        # builds where md5 is missing or blocked.
        prefix_hash = hashlib.sha256(
            prefix.encode("utf-8", "surrogatepass")
        ).hexdigest()
        if prefix_hash == self._prev_prefix_hash:
            self.cache_hits += 1
        else:
            self.cache_misses += 1
        self._prev_prefix_hash = prefix_hash

        self.total_prefix_tokens += prefix_tokens
        self.total_suffix_tokens += suffix_tokens

        return PromptLayout(
            prefix_content=prefix,
            suffix_content=suffix,
            prefix_tokens=prefix_tokens,
            suffix_tokens=suffix_tokens,
            cache_boundary=prefix_tokens,
            stable_sources=stable,
            dynamic_sources=dynamic,
        )

    def stats(self) -> Dict:
        """Summarise hit/miss counts and average token estimates.

        Returns:
            A dict with ``cache_hit_rate``, ``total_cache_hits``,
            ``total_cache_misses``, ``avg_prefix_tokens`` and
            ``avg_suffix_tokens``. Before any ``layout`` call the rate and
            averages are ``0.0``.
        """
        total = self.cache_hits + self.cache_misses
        # Clamp the divisor so an unused instance reports zeros, not an error.
        calls = max(total, 1)
        return {
            "cache_hit_rate": self.cache_hits / calls,
            "total_cache_hits": self.cache_hits,
            "total_cache_misses": self.cache_misses,
            "avg_prefix_tokens": self.total_prefix_tokens / calls,
            "avg_suffix_tokens": self.total_suffix_tokens / calls,
        }