File size: 4,633 Bytes
c8ece28
 
 
581261a
 
 
c8ece28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581261a
 
c8ece28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581261a
c8ece28
 
 
 
 
 
 
581261a
c8ece28
 
 
 
 
 
 
 
581261a
 
c8ece28
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Context Budgeter: Decides what context to include/exclude/summarize/retrieve."""
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass

@dataclass
class ContextBudget:
    """Outcome of one budgeting pass over the known context sources.

    Built by ``ContextBudgeter.budget``. Every list holds source names
    (keys of ``SOURCE_PRIORITIES``).
    """
    total_tokens: int  # overall token budget computed for the request
    sources: Dict[str, int]  # source_name -> tokens allocated (0 for on-demand sources that fit)
    keep_exact: List[str]  # sources included without summarization
    summarize: List[str]  # sources included in summarized (reduced-token) form
    omit: List[str]  # sources judged not needed for this task
    retrieve_on_demand: List[str]  # sources left out up front and deferred
    cache_prefix: List[str]  # sources with priority >= 0.9 (presumably the stable, cacheable part of the prompt)
    dynamic_suffix: List[str]  # other placed sources, priority < 0.9 (presumably the per-request part)

# Relative importance of each context source. ContextBudgeter.budget visits
# sources in descending priority (ties keep the order written here) and uses
# the thresholds 0.9 / 0.7 / 0.4 to decide placement and treatment:
#   >= 0.9  kept exact, cache prefix
#   >= 0.7  kept exact, dynamic suffix
#   >= 0.4  dynamic suffix, summarized when the estimate exceeds 800 tokens
#   <  0.4  retrieve on demand
SOURCE_PRIORITIES = {
    "system_rules": 1.0,       # Always include
    "tool_descriptions": 0.9,  # Almost always (skipped when tools are not needed)
    "recent_messages": 0.8,    # Important for coherence
    "task_plan": 0.7,          # Usually important
    "user_preferences": 0.6,
    "project_memory": 0.5,
    "prior_trace_failures": 0.5,
    "examples": 0.4,
    "retrieved_docs": 0.3,     # Retrieve on demand
    "artifacts": 0.3,
}

# Estimated token cost of including each source in full. These are fixed
# estimates, not measured sizes; ContextBudgeter.budget falls back to 500 for
# any source that has no entry here.
SOURCE_TOKEN_ESTIMATES = {
    "system_rules": 500,
    "tool_descriptions": 2000,
    "recent_messages": 1500,
    "task_plan": 300,
    "user_preferences": 100,
    "project_memory": 500,
    "prior_trace_failures": 300,
    "examples": 1000,
    "retrieved_docs": 3000,
    "artifacts": 1000,
}

# Per-task-type scale factor applied to the default budget in
# ContextBudgeter.budget (before the difficulty scaling). Task types not
# listed here use 0.7.
TASK_CONTEXT_MULTIPLIERS = {
    "quick_answer": 0.3,
    "document_drafting": 0.6,
    "tool_heavy": 0.7,
    "retrieval_heavy": 1.2,
    "research": 1.0,
    "coding": 0.8,
    "unknown_ambiguous": 0.5,
    "long_horizon": 1.0,
    "legal_regulated": 1.3,
}

class ContextBudgeter:
    """Split a per-request token budget across the known context sources.

    The budget is derived from the task type and difficulty, then handed out
    to sources in descending ``SOURCE_PRIORITIES`` order. Each source ends up
    kept exact, summarized, omitted, or deferred for on-demand retrieval.
    """

    def __init__(self, max_context: int = 128000, default_budget: int = 8000):
        """Store the default context limit and the base token budget.

        Args:
            max_context: Context limit used when ``budget`` is not given a
                model-specific limit.
            default_budget: Base budget, scaled per task type and difficulty.
        """
        self.max_context = max_context
        self.default_budget = default_budget

    def budget(self, task_type: str, difficulty: int, needs_retrieval: bool,
               needs_tools: bool, has_prior_failures: bool = False,
               model_context_limit: Optional[int] = None) -> ContextBudget:
        """Compute the context budget and per-source treatment for a task.

        Args:
            task_type: Key into ``TASK_CONTEXT_MULTIPLIERS``; unknown types
                use a multiplier of 0.7.
            difficulty: Each unit adds 20% to the budget.
            needs_retrieval: Whether ``retrieved_docs`` is relevant.
            needs_tools: Whether ``tool_descriptions`` is relevant.
            has_prior_failures: Whether ``prior_trace_failures`` is relevant.
            model_context_limit: Upper bound for the budget; a falsy value
                (``None`` or 0) falls back to ``self.max_context``.

        Returns:
            A ``ContextBudget`` describing the total budget, the tokens
            allocated per source, and which list each source landed in.
        """
        limit = model_context_limit or self.max_context
        mult = TASK_CONTEXT_MULTIPLIERS.get(task_type, 0.7)
        budget = int(self.default_budget * mult * (1 + difficulty * 0.2))
        # Clamp to [0, limit]: a strongly negative difficulty would otherwise
        # produce a negative budget and negative per-source allocations.
        budget = max(0, min(budget, limit))

        sources = {}
        keep_exact = []
        summarize = []
        omit = []
        retrieve_on_demand = []
        cache_prefix = []
        dynamic_suffix = []
        remaining = budget

        # Highest priority first; the sort is stable, so equal priorities keep
        # their declaration order.
        sorted_sources = sorted(SOURCE_PRIORITIES.items(), key=lambda x: -x[1])
        for source, priority in sorted_sources:
            est_tokens = SOURCE_TOKEN_ESTIMATES.get(source, 500)
            if not self._is_needed(source, task_type, needs_retrieval,
                                   needs_tools, has_prior_failures):
                omit.append(source)
                continue

            fits = remaining >= est_tokens
            if priority >= 0.7:
                # High-priority sources are always included. They are placed
                # in the prefix/suffix lists whether or not they fit in full,
                # so a truncated source is not lost from the prompt layout.
                keep_exact.append(source)
                if priority >= 0.9:
                    cache_prefix.append(source)
                else:
                    dynamic_suffix.append(source)
                if not fits:
                    # Truncate to whatever budget is left (possibly 0).
                    est_tokens = remaining
            elif not fits:
                # Lower-priority source that does not fit: defer it. No entry
                # is recorded in ``sources`` for this case.
                # NOTE(review): the fit check uses the unsummarized estimate,
                # so a source whose summary would fit is still deferred.
                retrieve_on_demand.append(source)
                continue
            elif priority >= 0.4:
                if est_tokens > 800:
                    # Summarize high-token sources down to at most 300 tokens.
                    summarize.append(source)
                    est_tokens = min(300, est_tokens // 3)
                else:
                    keep_exact.append(source)
                dynamic_suffix.append(source)
            else:
                # Low priority: never loaded up front, costs no tokens.
                retrieve_on_demand.append(source)
                est_tokens = 0

            sources[source] = est_tokens
            remaining -= est_tokens

        return ContextBudget(
            total_tokens=budget,
            sources=sources,
            keep_exact=keep_exact,
            summarize=summarize,
            omit=omit,
            retrieve_on_demand=retrieve_on_demand,
            cache_prefix=cache_prefix,
            dynamic_suffix=dynamic_suffix,
        )

    def _is_needed(self, source: str, task_type: str, needs_retrieval: bool,
                   needs_tools: bool, has_failures: bool) -> bool:
        """Return False when *source* is irrelevant for the given task flags."""
        if source == "retrieved_docs" and not needs_retrieval:
            return False
        if source == "tool_descriptions" and not needs_tools:
            return False
        if source == "prior_trace_failures" and not has_failures:
            return False
        if source == "examples" and task_type == "quick_answer":
            return False
        return True