| """Context Budgeter: Decides what context to include/exclude/summarize/retrieve.""" |
| from typing import Dict, List, Tuple, Optional |
| from dataclasses import dataclass |
|
|
@dataclass
class ContextBudget:
    """Token plan produced for a single request.

    Each context source ends up in exactly one of ``keep_exact``,
    ``summarize``, ``omit`` or ``retrieve_on_demand``; ``cache_prefix`` and
    ``dynamic_suffix`` describe where included sources are placed.
    """

    # Overall token budget the plan was computed against.
    total_tokens: int
    # Tokens allotted to each source that received an allocation.
    sources: Dict[str, int]
    # Sources included verbatim.
    keep_exact: List[str]
    # Sources included in condensed form.
    summarize: List[str]
    # Sources judged unnecessary for this request.
    omit: List[str]
    # Sources deferred instead of being preloaded into the context.
    retrieve_on_demand: List[str]
    # Sources placed in the prompt prefix (presumably the cacheable part).
    cache_prefix: List[str]
    # Sources placed after the prefix.
    dynamic_suffix: List[str]
|
|
# Relative importance of each context source; higher values are budgeted
# first. Declaration order breaks ties between equal priorities.
SOURCE_PRIORITIES = dict(
    system_rules=1.0,
    tool_descriptions=0.9,
    recent_messages=0.8,
    task_plan=0.7,
    user_preferences=0.6,
    project_memory=0.5,
    prior_trace_failures=0.5,
    examples=0.4,
    retrieved_docs=0.3,
    artifacts=0.3,
)
|
|
# Rough token cost of including each context source in full.
SOURCE_TOKEN_ESTIMATES = dict(
    system_rules=500,
    tool_descriptions=2000,
    recent_messages=1500,
    task_plan=300,
    user_preferences=100,
    project_memory=500,
    prior_trace_failures=300,
    examples=1000,
    retrieved_docs=3000,
    artifacts=1000,
)
|
|
# Scale factor applied to the base token budget for each task type.
TASK_CONTEXT_MULTIPLIERS = dict(
    quick_answer=0.3,
    document_drafting=0.6,
    tool_heavy=0.7,
    retrieval_heavy=1.2,
    research=1.0,
    coding=0.8,
    unknown_ambiguous=0.5,
    long_horizon=1.0,
    legal_regulated=1.3,
)
|
|
class ContextBudgeter:
    """Plan which context sources to include, summarize, defer, or omit.

    A token budget is derived from the task type and difficulty. It is then
    spent on context sources in descending ``SOURCE_PRIORITIES`` order, using
    the per-source sizes in ``SOURCE_TOKEN_ESTIMATES``.
    """

    # Sources at or above this priority are kept verbatim (truncated to the
    # remaining budget if they do not fit in full).
    _EXACT_PRIORITY = 0.7
    # Kept high-priority sources at or above this go in the prompt prefix;
    # the rest go in the dynamic suffix.
    _PREFIX_PRIORITY = 0.9
    # Sources in [_SUMMARIZE_PRIORITY, _EXACT_PRIORITY) are included, and
    # condensed when large. Anything lower is only retrieved on demand.
    _SUMMARIZE_PRIORITY = 0.4
    # Mid-priority sources estimated above this many tokens are summarized...
    _SUMMARIZE_ABOVE_TOKENS = 800
    # ...into at most this many tokens (and at most a third of the estimate).
    _SUMMARY_MAX_TOKENS = 300

    def __init__(self, max_context: int = 128000, default_budget: int = 8000):
        """Create a budgeter.

        Args:
            max_context: Token cap used when no model-specific limit is given.
            default_budget: Base budget before task/difficulty scaling.
        """
        self.max_context = max_context
        self.default_budget = default_budget

    def budget(self, task_type: str, difficulty: int, needs_retrieval: bool,
               needs_tools: bool, has_prior_failures: bool = False,
               model_context_limit: Optional[int] = None) -> ContextBudget:
        """Build a token plan for one request.

        Args:
            task_type: Key into ``TASK_CONTEXT_MULTIPLIERS``. Unknown types use
                a multiplier of 0.7.
            difficulty: Each unit adds 20% to the scaled base budget.
            needs_retrieval: Whether ``retrieved_docs`` is relevant.
            needs_tools: Whether ``tool_descriptions`` is relevant.
            has_prior_failures: Whether ``prior_trace_failures`` is relevant.
            model_context_limit: Hard cap on the budget. ``None`` (or 0)
                falls back to ``self.max_context``.

        Returns:
            A ContextBudget in which:
            - every needed source is either in ``omit`` or has an entry in
              ``sources``, and ``retrieve_on_demand`` entries are allotted 0;
            - every ``keep_exact`` source is allotted a positive token count;
            - every kept high-priority source appears in exactly one of
              ``cache_prefix`` or ``dynamic_suffix``;
            - the allotted tokens never exceed ``total_tokens``.
        """
        limit = model_context_limit or self.max_context
        mult = TASK_CONTEXT_MULTIPLIERS.get(task_type, 0.7)
        total = int(self.default_budget * mult * (1 + difficulty * 0.2))
        # Clamp at 0 so a negative difficulty or limit cannot yield a
        # negative budget (which would otherwise allot negative tokens).
        total = max(0, min(total, limit))

        sources: Dict[str, int] = {}
        keep_exact: List[str] = []
        summarize: List[str] = []
        omit: List[str] = []
        retrieve_on_demand: List[str] = []
        cache_prefix: List[str] = []
        dynamic_suffix: List[str] = []
        remaining = total

        # sorted() is stable, so equal-priority sources keep their
        # SOURCE_PRIORITIES declaration order.
        ranked = sorted(SOURCE_PRIORITIES.items(), key=lambda kv: -kv[1])
        for source, priority in ranked:
            if not self._is_needed(source, task_type, needs_retrieval,
                                   needs_tools, has_prior_failures):
                omit.append(source)
                continue

            estimate = SOURCE_TOKEN_ESTIMATES.get(source, 500)

            if priority >= self._EXACT_PRIORITY:
                if remaining <= 0:
                    # Budget exhausted: a zero-token "exact" entry would be
                    # meaningless, so defer the source instead.
                    cost = 0
                    retrieve_on_demand.append(source)
                else:
                    # Keep verbatim, truncated to the remaining budget if needed.
                    cost = min(estimate, remaining)
                    keep_exact.append(source)
                    if priority >= self._PREFIX_PRIORITY:
                        cache_prefix.append(source)
                    else:
                        dynamic_suffix.append(source)
            elif priority >= self._SUMMARIZE_PRIORITY:
                large = estimate > self._SUMMARIZE_ABOVE_TOKENS
                if large:
                    cost = min(self._SUMMARY_MAX_TOKENS, estimate // 3)
                else:
                    cost = estimate
                # Gate on the tokens actually charged: a large source that
                # fits once summarized is summarized, not deferred.
                if cost > remaining:
                    cost = 0
                    retrieve_on_demand.append(source)
                elif large:
                    summarize.append(source)
                else:
                    keep_exact.append(source)
                    dynamic_suffix.append(source)
            else:
                # Low priority: never preloaded.
                cost = 0
                retrieve_on_demand.append(source)

            sources[source] = cost
            remaining -= cost

        return ContextBudget(
            total_tokens=total,
            sources=sources,
            keep_exact=keep_exact,
            summarize=summarize,
            omit=omit,
            retrieve_on_demand=retrieve_on_demand,
            cache_prefix=cache_prefix,
            dynamic_suffix=dynamic_suffix,
        )

    def _is_needed(self, source: str, task_type: str, needs_retrieval: bool,
                   needs_tools: bool, has_failures: bool) -> bool:
        """Return False when the request flags make ``source`` irrelevant."""
        if source == "retrieved_docs":
            return bool(needs_retrieval)
        if source == "tool_descriptions":
            return bool(needs_tools)
        if source == "prior_trace_failures":
            return bool(has_failures)
        if source == "examples":
            # Examples add little to one-shot quick answers.
            return task_type != "quick_answer"
        return True
|
|