narcolepticchicken commited on
Commit
c8ece28
·
verified ·
1 Parent(s): e6100b5

Upload aco/context_budgeter.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/context_budgeter.py +116 -200
aco/context_budgeter.py CHANGED
@@ -1,209 +1,125 @@
1
- """Context Budgeter - Module 4.
2
-
3
- Decides what context is needed, what can be omitted/summarized, and what should be retrieved.
4
-
5
- Context sources:
6
- - system rules
7
- - tool descriptions
8
- - user preferences
9
- - project memory
10
- - retrieved docs
11
- - prior trace failures
12
- - examples
13
- - recent messages
14
- - artifacts
15
- - task plan
16
- """
17
-
18
- from typing import Dict, List, Tuple, Optional, Any
19
- from dataclasses import dataclass, field
20
-
21
- from .trace_schema import TaskType
22
- from .config import ACOConfig
23
-
24
-
25
- @dataclass
26
- class ContextSource:
27
- name: str
28
- tokens: int
29
- importance: float # 0-1
30
- staleness: float # 0=current, 1=very stale
31
- mutable: bool # True if content changes per turn
32
- cacheable: bool # True if can be prefix-cached
33
- summary: Optional[str] = None
34
-
35
 
36
  @dataclass
37
  class ContextBudget:
38
- total_budget_tokens: int
39
- allocated_sources: List[ContextSource]
40
- omitted_sources: List[ContextSource]
41
- summarized_sources: List[Tuple[ContextSource, str]] # (source, summary_text)
42
- retrieval_queries: List[str]
43
- cache_prefix_tokens: int
44
- dynamic_suffix_tokens: int
45
- estimated_cost: float
46
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  class ContextBudgeter:
49
- """Intelligently budgets context window to minimize cost while preserving quality."""
50
-
51
- # Task-specific context budgets (tokens)
52
- DEFAULT_BUDGETS = {
53
- TaskType.QUICK_ANSWER: 2048,
54
- TaskType.UNKNOWN_AMBIGUOUS: 4096,
55
- TaskType.TOOL_HEAVY: 8192,
56
- TaskType.RETRIEVAL_HEAVY: 16384,
57
- TaskType.DOCUMENT_DRAFTING: 8192,
58
- TaskType.CODING: 12288,
59
- TaskType.RESEARCH: 16384,
60
- TaskType.LONG_HORIZON: 32768,
61
- TaskType.LEGAL_REGULATED: 24576,
62
- }
63
-
64
- # Importance weights by task type
65
- IMPORTANCE_RULES = {
66
- TaskType.QUICK_ANSWER: {
67
- "system_rules": 0.9,
68
- "recent_messages": 0.9,
69
- "user_preferences": 0.5,
70
- },
71
- TaskType.CODING: {
72
- "system_rules": 0.8,
73
- "tool_descriptions": 0.9,
74
- "artifacts": 0.9,
75
- "recent_messages": 0.7,
76
- "examples": 0.6,
77
- },
78
- TaskType.RESEARCH: {
79
- "retrieved_docs": 0.95,
80
- "recent_messages": 0.6,
81
- "system_rules": 0.5,
82
- "task_plan": 0.8,
83
- },
84
- TaskType.LEGAL_REGULATED: {
85
- "retrieved_docs": 0.95,
86
- "system_rules": 0.9,
87
- "user_preferences": 0.7,
88
- "artifacts": 0.8,
89
- },
90
- TaskType.LONG_HORIZON: {
91
- "task_plan": 0.95,
92
- "recent_messages": 0.8,
93
- "artifacts": 0.85,
94
- "system_rules": 0.7,
95
- "prior_trace_failures": 0.6,
96
- },
97
- }
98
-
99
- def __init__(self, config: Optional[ACOConfig] = None):
100
- self.config = config or ACOConfig()
101
-
102
- def budget(
103
- self,
104
- task_type: TaskType,
105
- available_sources: List[ContextSource],
106
- model_max_context: int = 128000,
107
- cost_per_1k_input: float = 0.01,
108
- ) -> ContextBudget:
109
- """Allocate context budget across sources."""
110
-
111
- budget_tokens = self.DEFAULT_BUDGETS.get(task_type, 8192)
112
- # Don't exceed model limit
113
- budget_tokens = min(budget_tokens, int(model_max_context * 0.8))
114
-
115
- # Apply importance rules
116
- importance_map = self.IMPORTANCE_RULES.get(task_type, {})
117
- for source in available_sources:
118
- source.importance = importance_map.get(source.name, source.importance)
119
-
120
- # Separate stable (cacheable) vs dynamic
121
- stable_sources = [s for s in available_sources if s.cacheable and not s.mutable]
122
- dynamic_sources = [s for s in available_sources if s.mutable or not s.cacheable]
123
-
124
- # Always include stable sources in prefix (they're cache-efficient)
125
- prefix_tokens = sum(s.tokens for s in stable_sources)
126
- remaining = budget_tokens - prefix_tokens
127
-
128
- # Sort dynamic by importance / staleness ratio
129
- dynamic_sources.sort(key=lambda s: s.importance / (1 + s.staleness), reverse=True)
130
-
131
- allocated = list(stable_sources)
132
- omitted = []
133
- summarized = []
134
- retrieval_queries = []
135
-
136
- for source in dynamic_sources:
137
- if source.tokens <= remaining:
138
- allocated.append(source)
139
- remaining -= source.tokens
140
- elif source.importance > 0.7 and source.tokens > remaining * 1.5:
141
- # Source is important but too big — summarize it
142
- summary_tokens = min(int(remaining * 0.3), 512)
143
- if summary_tokens > 50:
144
- summary = self._summarize(source, summary_tokens)
145
- summarized.append((source, summary))
146
- remaining -= summary_tokens
147
- elif source.importance > 0.8:
148
- # Critical but doesn't fit — mark for retrieval instead
149
- omitted.append(source)
150
- retrieval_queries.append(f"retrieve:{source.name}")
151
  else:
152
- omitted.append(source)
153
-
154
- dynamic_used = sum(s.tokens for s in allocated if s in dynamic_sources)
155
- dynamic_used += sum(len(s[1].split()) for s in summarized)
156
-
157
- total_tokens = prefix_tokens + dynamic_used
158
- estimated_cost = (total_tokens / 1000) * cost_per_1k_input
159
-
160
  return ContextBudget(
161
- total_budget_tokens=budget_tokens,
162
- allocated_sources=allocated,
163
- omitted_sources=omitted,
164
- summarized_sources=summarized,
165
- retrieval_queries=retrieval_queries,
166
- cache_prefix_tokens=prefix_tokens,
167
- dynamic_suffix_tokens=dynamic_used,
168
- estimated_cost=estimated_cost,
169
  )
170
 
171
- def _summarize(self, source: ContextSource, max_tokens: int) -> str:
172
- """Produce a token-budgeted summary of a context source."""
173
- # In production, use a summarization model
174
- # Here we return a placeholder
175
- return f"[SUMMARY:{source.name}:{max_tokens}tokens]"
176
-
177
- def should_retrieve(self, source: ContextSource, task_type: TaskType) -> bool:
178
- """Decide if a source should be retrieved on-demand vs. kept in context."""
179
- if source.staleness > 0.5:
180
- return True
181
- if source.tokens > 4096 and source.importance < 0.8:
182
- return True
183
- if task_type in (TaskType.RESEARCH, TaskType.RETRIEVAL_HEAVY):
184
- return True
185
- return False
186
-
187
- def compress_history(
188
- self,
189
- messages: List[Dict[str, str]],
190
- max_messages: int = 10,
191
- summarize_older: bool = True,
192
- ) -> List[Dict[str, str]]:
193
- """Compress message history by summarizing older messages."""
194
- if len(messages) <= max_messages:
195
- return messages
196
-
197
- keep = messages[-max_messages:]
198
- older = messages[:-max_messages]
199
-
200
- if summarize_older and older:
201
- summary = self._summarize_messages(older)
202
- return [{"role": "system", "content": f"[Earlier context summary]: {summary}"}] + keep
203
-
204
- return keep
205
-
206
- def _summarize_messages(self, messages: List[Dict[str, str]]) -> str:
207
- """Summarize a list of messages."""
208
- # In production, use a summarization model
209
- return f"{len(messages)} earlier messages summarized."
 
1
+ """Context Budgeter: Decides what context to include/exclude/summarize/retrieve."""
2
+ from typing import Dict, List, Tuple, Optional
3
+ from dataclasses import dataclass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  @dataclass
6
  class ContextBudget:
7
+ total_tokens: int
8
+ sources: Dict[str, int] # source_name -> token_count
9
+ keep_exact: List[str]
10
+ summarize: List[str]
11
+ omit: List[str]
12
+ retrieve_on_demand: List[str]
13
+ cache_prefix: List[str]
14
+ dynamic_suffix: List[str]
15
+
16
+ SOURCE_PRIORITIES = {
17
+ "system_rules": 1.0, # Always include
18
+ "tool_descriptions": 0.9, # Almost always
19
+ "recent_messages": 0.8, # Important for coherence
20
+ "task_plan": 0.7, # Usually important
21
+ "user_preferences": 0.6,
22
+ "project_memory": 0.5,
23
+ "prior_trace_failures": 0.5,
24
+ "examples": 0.4,
25
+ "retrieved_docs": 0.3, # Retrieve on demand
26
+ "artifacts": 0.3,
27
+ }
28
+
29
+ SOURCE_TOKEN_ESTIMATES = {
30
+ "system_rules": 500,
31
+ "tool_descriptions": 2000,
32
+ "recent_messages": 1500,
33
+ "task_plan": 300,
34
+ "user_preferences": 100,
35
+ "project_memory": 500,
36
+ "prior_trace_failures": 300,
37
+ "examples": 1000,
38
+ "retrieved_docs": 3000,
39
+ "artifacts": 1000,
40
+ }
41
+
42
+ TASK_CONTEXT_MULTIPLIERS = {
43
+ "quick_answer": 0.3,
44
+ "document_drafting": 0.6,
45
+ "tool_heavy": 0.7,
46
+ "retrieval_heavy": 1.2,
47
+ "research": 1.0,
48
+ "coding": 0.8,
49
+ "unknown_ambiguous": 0.5,
50
+ "long_horizon": 1.0,
51
+ "legal_regulated": 1.3,
52
+ }
53
 
54
  class ContextBudgeter:
55
+ def __init__(self, max_context: int = 128000, default_budget: int = 8000):
56
+ self.max_context = max_context
57
+ self.default_budget = default_budget
58
+
59
+ def budget(self, task_type: str, difficulty: int, needs_retrieval: bool,
60
+ needs_tools: bool, has_prior_failures: bool = False,
61
+ model_context_limit: int = None) -> ContextBudget:
62
+ limit = model_context_limit or self.max_context
63
+ mult = TASK_CONTEXT_MULTIPLIERS.get(task_type, 0.7)
64
+ budget = int(self.default_budget * mult * (1 + difficulty * 0.2))
65
+ budget = min(budget, limit)
66
+ sources = {}
67
+ keep_exact = []
68
+ summarize = []
69
+ omit = []
70
+ retrieve_on_demand = []
71
+ cache_prefix = []
72
+ dynamic_suffix = []
73
+ remaining = budget
74
+ # Sort sources by priority
75
+ sorted_sources = sorted(SOURCE_PRIORITIES.items(), key=lambda x: -x[1])
76
+ for source, priority in sorted_sources:
77
+ est_tokens = SOURCE_TOKEN_ESTIMATES.get(source, 500)
78
+ # Check if this source is needed for this task
79
+ needed = self._is_needed(source, task_type, needs_retrieval, needs_tools, has_prior_failures)
80
+ if not needed:
81
+ omit.append(source)
82
+ continue
83
+ if remaining >= est_tokens:
84
+ if priority >= 0.7:
85
+ keep_exact.append(source)
86
+ cache_prefix.append(source) if priority >= 0.9 else dynamic_suffix.append(source)
87
+ elif priority >= 0.4:
88
+ # Summarize high-token sources
89
+ if est_tokens > 800:
90
+ summarize.append(source)
91
+ est_tokens = min(300, est_tokens // 3)
92
+ else:
93
+ keep_exact.append(source)
94
+ dynamic_suffix.append(source)
95
+ else:
96
+ retrieve_on_demand.append(source)
97
+ est_tokens = 0
98
+ sources[source] = est_tokens
99
+ remaining -= est_tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  else:
101
+ if priority >= 0.7:
102
+ # Always include high-priority, even if truncated
103
+ keep_exact.append(source)
104
+ sources[source] = remaining
105
+ remaining = 0
106
+ else:
107
+ retrieve_on_demand.append(source)
 
108
  return ContextBudget(
109
+ total_tokens=budget,
110
+ sources=sources,
111
+ keep_exact=keep_exact,
112
+ summarize=summarize,
113
+ omit=omit,
114
+ retrieve_on_demand=retrieve_on_demand,
115
+ cache_prefix=cache_prefix,
116
+ dynamic_suffix=dynamic_suffix,
117
  )
118
 
119
+ def _is_needed(self, source: str, task_type: str, needs_retrieval: bool,
120
+ needs_tools: bool, has_failures: bool) -> bool:
121
+ if source == "retrieved_docs" and not needs_retrieval: return False
122
+ if source == "tool_descriptions" and not needs_tools: return False
123
+ if source == "prior_trace_failures" and not has_failures: return False
124
+ if source == "examples" and task_type == "quick_answer": return False
125
+ return True