agent-cost-optimizer / aco /context_budgeter.py
narcolepticchicken's picture
Upload aco/context_budgeter.py with huggingface_hub
c8ece28 verified
"""Context Budgeter: Decides what context to include/exclude/summarize/retrieve."""
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
@dataclass()
class ContextBudget:
    """Outcome of one context-budgeting decision.

    ``total_tokens`` is the overall budget that was computed, ``sources`` maps
    source names to the tokens allotted to them, and each list field names the
    sources that fall into one handling category.
    """

    total_tokens: int  # overall token budget for the request
    sources: Dict[str, int]  # source_name -> allotted token count
    keep_exact: List[str]  # sources kept without summarization
    summarize: List[str]  # sources to be included in summarized form
    omit: List[str]  # sources left out of the context entirely
    retrieve_on_demand: List[str]  # sources deferred to on-demand retrieval
    cache_prefix: List[str]  # sources intended for the cacheable prompt prefix (by name; confirm with caller)
    dynamic_suffix: List[str]  # sources intended for the per-request part of the prompt (by name; confirm with caller)
# Relative importance of each context source, from 1.0 (always include) down to
# 0.3 (retrieve on demand). Listing order matters: sources with equal priority
# are visited in this order, because ContextBudgeter uses a stable sort.
SOURCE_PRIORITIES = dict(
    system_rules=1.0,          # always include
    tool_descriptions=0.9,     # almost always
    recent_messages=0.8,       # important for coherence
    task_plan=0.7,             # usually important
    user_preferences=0.6,
    project_memory=0.5,
    prior_trace_failures=0.5,
    examples=0.4,
    retrieved_docs=0.3,        # retrieve on demand
    artifacts=0.3,
)
# Static per-source token-size estimates used when budgeting.
# ContextBudgeter.budget falls back to 500 tokens for any source missing here.
SOURCE_TOKEN_ESTIMATES = dict(
    system_rules=500,
    tool_descriptions=2000,
    recent_messages=1500,
    task_plan=300,
    user_preferences=100,
    project_memory=500,
    prior_trace_failures=300,
    examples=1000,
    retrieved_docs=3000,
    artifacts=1000,
)
# Per-task-type scale factor applied to ContextBudgeter.default_budget.
# Task types not listed here use a multiplier of 0.7.
TASK_CONTEXT_MULTIPLIERS = dict(
    quick_answer=0.3,
    document_drafting=0.6,
    tool_heavy=0.7,
    retrieval_heavy=1.2,
    research=1.0,
    coding=0.8,
    unknown_ambiguous=0.5,
    long_horizon=1.0,
    legal_regulated=1.3,
)
class ContextBudgeter:
    """Split a per-request token budget across the known context sources.

    Priorities, size estimates and task-type multipliers come from the
    module-level ``SOURCE_PRIORITIES``, ``SOURCE_TOKEN_ESTIMATES`` and
    ``TASK_CONTEXT_MULTIPLIERS`` tables.
    """

    def __init__(self, max_context: int = 128000, default_budget: int = 8000):
        # Upper bound on the budget when no model-specific limit is supplied.
        self.max_context = max_context
        # Base budget before task-type and difficulty scaling.
        self.default_budget = default_budget

    @staticmethod
    def _summary_tokens(est_tokens: int) -> int:
        """Return the tokens charged for a summarized source: a third of its estimate, capped at 300."""
        return min(300, est_tokens // 3)

    def budget(self, task_type: str, difficulty: int, needs_retrieval: bool,
               needs_tools: bool, has_prior_failures: bool = False,
               model_context_limit: Optional[int] = None) -> ContextBudget:
        """Decide how each context source is handled for one request.

        The budget is ``default_budget`` scaled by the task-type multiplier
        (``TASK_CONTEXT_MULTIPLIERS``, default 0.7) and by
        ``1 + 0.2 * difficulty``, then clamped to ``[0, limit]``, where
        ``limit`` is ``model_context_limit`` if truthy, else ``max_context``.

        Sources are visited from highest to lowest priority:

        * not needed for this request (see ``_is_needed``) -> ``omit``;
        * priority >= 0.7 -> ``keep_exact``, truncated to the remaining budget
          when it does not fit (possibly down to 0 tokens); also placed in
          ``cache_prefix`` when priority >= 0.9, else in ``dynamic_suffix``;
        * 0.4 <= priority < 0.7 -> sources estimated above 800 tokens go to
          ``summarize`` at ``min(300, est // 3)`` tokens, smaller ones to
          ``keep_exact`` + ``dynamic_suffix`` at full size; if that cost does
          not fit in the remaining budget, the source goes to
          ``retrieve_on_demand`` instead;
        * priority < 0.4 -> ``retrieve_on_demand``.

        Args:
            task_type: Key into ``TASK_CONTEXT_MULTIPLIERS``.
            difficulty: Scales the budget by 20% per unit.
            needs_retrieval: Whether ``retrieved_docs`` is relevant.
            needs_tools: Whether ``tool_descriptions`` is relevant.
            has_prior_failures: Whether ``prior_trace_failures`` is relevant.
            model_context_limit: Optional cap overriding ``max_context``.

        Returns:
            A ``ContextBudget``. Its ``sources`` maps every non-omitted source
            to its allotted tokens (0 for retrieve-on-demand sources).
        """
        limit = model_context_limit or self.max_context
        mult = TASK_CONTEXT_MULTIPLIERS.get(task_type, 0.7)
        budget = int(self.default_budget * mult * (1 + difficulty * 0.2))
        # Clamp at 0 as well: a negative difficulty would otherwise yield a
        # negative budget and negative per-source allocations.
        budget = max(0, min(budget, limit))

        sources: Dict[str, int] = {}
        keep_exact: List[str] = []
        summarize: List[str] = []
        omit: List[str] = []
        retrieve_on_demand: List[str] = []
        cache_prefix: List[str] = []
        dynamic_suffix: List[str] = []
        remaining = budget

        # Stable sort: equal-priority sources keep their SOURCE_PRIORITIES order.
        ranked = sorted(SOURCE_PRIORITIES.items(), key=lambda kv: -kv[1])
        for source, priority in ranked:
            if not self._is_needed(source, task_type, needs_retrieval,
                                   needs_tools, has_prior_failures):
                omit.append(source)
                continue

            est_tokens = SOURCE_TOKEN_ESTIMATES.get(source, 500)

            if priority >= 0.7:
                # High-priority sources are always kept; if they do not fit
                # they are truncated to whatever budget is left.
                allocated = min(est_tokens, remaining)
                keep_exact.append(source)
                if priority >= 0.9:
                    cache_prefix.append(source)
                else:
                    dynamic_suffix.append(source)
            elif priority >= 0.4:
                # Large mid-priority sources are summarized. The fit check
                # uses the summarized cost, since that is what gets spent.
                summarized = est_tokens > 800
                allocated = self._summary_tokens(est_tokens) if summarized else est_tokens
                if allocated > remaining:
                    retrieve_on_demand.append(source)
                    allocated = 0
                elif summarized:
                    summarize.append(source)
                else:
                    keep_exact.append(source)
                    dynamic_suffix.append(source)
            else:
                # Low-priority sources are never inlined.
                retrieve_on_demand.append(source)
                allocated = 0

            sources[source] = allocated
            remaining -= allocated

        return ContextBudget(
            total_tokens=budget,
            sources=sources,
            keep_exact=keep_exact,
            summarize=summarize,
            omit=omit,
            retrieve_on_demand=retrieve_on_demand,
            cache_prefix=cache_prefix,
            dynamic_suffix=dynamic_suffix,
        )

    def _is_needed(self, source: str, task_type: str, needs_retrieval: bool,
                   needs_tools: bool, has_failures: bool) -> bool:
        """Return False for a source that is irrelevant to this request, else True.

        ``retrieved_docs``, ``tool_descriptions`` and ``prior_trace_failures``
        are gated by their respective flags; ``examples`` is skipped for
        ``quick_answer`` tasks. Every other source is always needed.
        """
        if source == "retrieved_docs":
            return bool(needs_retrieval)
        if source == "tool_descriptions":
            return bool(needs_tools)
        if source == "prior_trace_failures":
            return bool(has_failures)
        if source == "examples":
            return task_type != "quick_answer"
        return True