narcolepticchicken commited on
Commit
581261a
·
verified ·
1 Parent(s): 2b522a0

Upload aco/context_budgeter.py

Browse files
Files changed (1) hide show
  1. aco/context_budgeter.py +209 -0
aco/context_budgeter.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Context Budgeter - Module 4.
2
+
3
+ Decides what context is needed, what can be omitted/summarized, and what should be retrieved.
4
+
5
+ Context sources:
6
+ - system rules
7
+ - tool descriptions
8
+ - user preferences
9
+ - project memory
10
+ - retrieved docs
11
+ - prior trace failures
12
+ - examples
13
+ - recent messages
14
+ - artifacts
15
+ - task plan
16
+ """
17
+
18
+ from typing import Dict, List, Tuple, Optional, Any
19
+ from dataclasses import dataclass, field
20
+
21
+ from .trace_schema import TaskType
22
+ from .config import ACOConfig
23
+
24
+
25
+ @dataclass
26
+ class ContextSource:
27
+ name: str
28
+ tokens: int
29
+ importance: float # 0-1
30
+ staleness: float # 0=current, 1=very stale
31
+ mutable: bool # True if content changes per turn
32
+ cacheable: bool # True if can be prefix-cached
33
+ summary: Optional[str] = None
34
+
35
+
36
+ @dataclass
37
+ class ContextBudget:
38
+ total_budget_tokens: int
39
+ allocated_sources: List[ContextSource]
40
+ omitted_sources: List[ContextSource]
41
+ summarized_sources: List[Tuple[ContextSource, str]] # (source, summary_text)
42
+ retrieval_queries: List[str]
43
+ cache_prefix_tokens: int
44
+ dynamic_suffix_tokens: int
45
+ estimated_cost: float
46
+
47
+
48
+ class ContextBudgeter:
49
+ """Intelligently budgets context window to minimize cost while preserving quality."""
50
+
51
+ # Task-specific context budgets (tokens)
52
+ DEFAULT_BUDGETS = {
53
+ TaskType.QUICK_ANSWER: 2048,
54
+ TaskType.UNKNOWN_AMBIGUOUS: 4096,
55
+ TaskType.TOOL_HEAVY: 8192,
56
+ TaskType.RETRIEVAL_HEAVY: 16384,
57
+ TaskType.DOCUMENT_DRAFTING: 8192,
58
+ TaskType.CODING: 12288,
59
+ TaskType.RESEARCH: 16384,
60
+ TaskType.LONG_HORIZON: 32768,
61
+ TaskType.LEGAL_REGULATED: 24576,
62
+ }
63
+
64
+ # Importance weights by task type
65
+ IMPORTANCE_RULES = {
66
+ TaskType.QUICK_ANSWER: {
67
+ "system_rules": 0.9,
68
+ "recent_messages": 0.9,
69
+ "user_preferences": 0.5,
70
+ },
71
+ TaskType.CODING: {
72
+ "system_rules": 0.8,
73
+ "tool_descriptions": 0.9,
74
+ "artifacts": 0.9,
75
+ "recent_messages": 0.7,
76
+ "examples": 0.6,
77
+ },
78
+ TaskType.RESEARCH: {
79
+ "retrieved_docs": 0.95,
80
+ "recent_messages": 0.6,
81
+ "system_rules": 0.5,
82
+ "task_plan": 0.8,
83
+ },
84
+ TaskType.LEGAL_REGULATED: {
85
+ "retrieved_docs": 0.95,
86
+ "system_rules": 0.9,
87
+ "user_preferences": 0.7,
88
+ "artifacts": 0.8,
89
+ },
90
+ TaskType.LONG_HORIZON: {
91
+ "task_plan": 0.95,
92
+ "recent_messages": 0.8,
93
+ "artifacts": 0.85,
94
+ "system_rules": 0.7,
95
+ "prior_trace_failures": 0.6,
96
+ },
97
+ }
98
+
99
+ def __init__(self, config: Optional[ACOConfig] = None):
100
+ self.config = config or ACOConfig()
101
+
102
+ def budget(
103
+ self,
104
+ task_type: TaskType,
105
+ available_sources: List[ContextSource],
106
+ model_max_context: int = 128000,
107
+ cost_per_1k_input: float = 0.01,
108
+ ) -> ContextBudget:
109
+ """Allocate context budget across sources."""
110
+
111
+ budget_tokens = self.DEFAULT_BUDGETS.get(task_type, 8192)
112
+ # Don't exceed model limit
113
+ budget_tokens = min(budget_tokens, int(model_max_context * 0.8))
114
+
115
+ # Apply importance rules
116
+ importance_map = self.IMPORTANCE_RULES.get(task_type, {})
117
+ for source in available_sources:
118
+ source.importance = importance_map.get(source.name, source.importance)
119
+
120
+ # Separate stable (cacheable) vs dynamic
121
+ stable_sources = [s for s in available_sources if s.cacheable and not s.mutable]
122
+ dynamic_sources = [s for s in available_sources if s.mutable or not s.cacheable]
123
+
124
+ # Always include stable sources in prefix (they're cache-efficient)
125
+ prefix_tokens = sum(s.tokens for s in stable_sources)
126
+ remaining = budget_tokens - prefix_tokens
127
+
128
+ # Sort dynamic by importance / staleness ratio
129
+ dynamic_sources.sort(key=lambda s: s.importance / (1 + s.staleness), reverse=True)
130
+
131
+ allocated = list(stable_sources)
132
+ omitted = []
133
+ summarized = []
134
+ retrieval_queries = []
135
+
136
+ for source in dynamic_sources:
137
+ if source.tokens <= remaining:
138
+ allocated.append(source)
139
+ remaining -= source.tokens
140
+ elif source.importance > 0.7 and source.tokens > remaining * 1.5:
141
+ # Source is important but too big — summarize it
142
+ summary_tokens = min(int(remaining * 0.3), 512)
143
+ if summary_tokens > 50:
144
+ summary = self._summarize(source, summary_tokens)
145
+ summarized.append((source, summary))
146
+ remaining -= summary_tokens
147
+ elif source.importance > 0.8:
148
+ # Critical but doesn't fit — mark for retrieval instead
149
+ omitted.append(source)
150
+ retrieval_queries.append(f"retrieve:{source.name}")
151
+ else:
152
+ omitted.append(source)
153
+
154
+ dynamic_used = sum(s.tokens for s in allocated if s in dynamic_sources)
155
+ dynamic_used += sum(len(s[1].split()) for s in summarized)
156
+
157
+ total_tokens = prefix_tokens + dynamic_used
158
+ estimated_cost = (total_tokens / 1000) * cost_per_1k_input
159
+
160
+ return ContextBudget(
161
+ total_budget_tokens=budget_tokens,
162
+ allocated_sources=allocated,
163
+ omitted_sources=omitted,
164
+ summarized_sources=summarized,
165
+ retrieval_queries=retrieval_queries,
166
+ cache_prefix_tokens=prefix_tokens,
167
+ dynamic_suffix_tokens=dynamic_used,
168
+ estimated_cost=estimated_cost,
169
+ )
170
+
171
+ def _summarize(self, source: ContextSource, max_tokens: int) -> str:
172
+ """Produce a token-budgeted summary of a context source."""
173
+ # In production, use a summarization model
174
+ # Here we return a placeholder
175
+ return f"[SUMMARY:{source.name}:{max_tokens}tokens]"
176
+
177
+ def should_retrieve(self, source: ContextSource, task_type: TaskType) -> bool:
178
+ """Decide if a source should be retrieved on-demand vs. kept in context."""
179
+ if source.staleness > 0.5:
180
+ return True
181
+ if source.tokens > 4096 and source.importance < 0.8:
182
+ return True
183
+ if task_type in (TaskType.RESEARCH, TaskType.RETRIEVAL_HEAVY):
184
+ return True
185
+ return False
186
+
187
+ def compress_history(
188
+ self,
189
+ messages: List[Dict[str, str]],
190
+ max_messages: int = 10,
191
+ summarize_older: bool = True,
192
+ ) -> List[Dict[str, str]]:
193
+ """Compress message history by summarizing older messages."""
194
+ if len(messages) <= max_messages:
195
+ return messages
196
+
197
+ keep = messages[-max_messages:]
198
+ older = messages[:-max_messages]
199
+
200
+ if summarize_older and older:
201
+ summary = self._summarize_messages(older)
202
+ return [{"role": "system", "content": f"[Earlier context summary]: {summary}"}] + keep
203
+
204
+ return keep
205
+
206
+ def _summarize_messages(self, messages: List[Dict[str, str]]) -> str:
207
+ """Summarize a list of messages."""
208
+ # In production, use a summarization model
209
+ return f"{len(messages)} earlier messages summarized."