File size: 5,603 Bytes
9146136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Per-Step Router: Route each agent step, not just the task.

Key insight from BAAR paper: Tasks get harder mid-execution.
A task that starts easy (search codebase) may end hard (patch critical bug).
The router should re-evaluate at each step based on:
- Current step type (search, read, edit, execute, verify)
- Results so far (errors found, tools failed)
- Remaining budget
- Confidence in current trajectory

Step types and their typical tier requirements:
- search/list: tier 1-2 (easy, pattern matching)
- read/understand: tier 2-3 (moderate comprehension)
- edit/patch: tier 3-4 (requires deep understanding)
- execute/test: tier 2-3 (pattern matching on errors)
- verify/review: tier 3-5 (requires expert judgment)
- plan/orchestrate: tier 4-5 (strategic thinking)
"""
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum

class StepType(Enum):
    """Category of a single agent step, used to pick a base model tier.

    The member values are the lower-case string names of the categories.
    """
    SEARCH = "search"            # search / list style steps
    READ = "read"                # read / understand style steps
    EDIT = "edit"                # edit / patch style steps
    EXECUTE = "execute"          # execute / test style steps
    VERIFY = "verify"            # verify / review style steps
    PLAN = "plan"                # plan / orchestrate style steps
    COMMUNICATE = "communicate"  # ask / tell / report style steps
    UNKNOWN = "unknown"          # fallback when no keyword pattern matches

# Base tier for each step type (before context adjustment).
# PerStepRouter.route_step starts from this value and then applies its
# escalation, risk-floor, budget and late-step adjustments on top of it.
STEP_TYPE_TIER = {
    StepType.SEARCH: 2,
    StepType.READ: 2,
    StepType.EDIT: 3,
    StepType.EXECUTE: 2,
    StepType.VERIFY: 3,
    StepType.PLAN: 4,
    StepType.COMMUNICATE: 2,
    StepType.UNKNOWN: 3,
}

# Step type detection from action description.
#
# PerStepRouter.classify_step looks for these keywords in the lower-cased
# action text. The dict's insertion order is the priority order: the first
# step type with a matching keyword wins.
#
# NOTE(review): "check" is listed under both READ and VERIFY; because READ
# comes first, an action that only matches "check" is classified as READ.
STEP_PATTERNS = {
    StepType.SEARCH: [
        "search", "find", "grep", "list", "locate", "look up", "query",
    ],
    StepType.READ: [
        "read", "cat", "show", "display", "view", "inspect", "open", "check",
    ],
    StepType.EDIT: [
        "edit", "write", "patch", "modify", "change",
        "fix", "update", "insert", "delete", "replace",
    ],
    StepType.EXECUTE: [
        "run", "execute", "test", "compile", "build", "bash", "shell", "python",
    ],
    StepType.VERIFY: [
        "verify", "validate", "check", "review", "confirm", "assert", "lint",
    ],
    StepType.PLAN: [
        "plan", "decide", "choose", "strategy", "orchestrate", "coordinate",
    ],
    StepType.COMMUNICATE: [
        "ask", "tell", "inform", "report", "summarize", "explain",
    ],
}

@dataclass
class StepRoutingDecision:
    """Outcome of routing one agent step, as returned by PerStepRouter.route_step."""
    step_type: StepType    # category the action text was classified as
    base_tier: int         # tier from STEP_TYPE_TIER before any adjustment
    adjusted_tier: int     # final tier after all adjustments were applied
    reason: str            # "; "-joined adjustment tags, or "default" if none applied
    cost_estimate: float   # cost charged for the adjusted tier
    model_id: str          # model identifier mapped from the adjusted tier

class PerStepRouter:
    """Choose a model tier for every agent step while tracking a spend budget.

    Each call to ``route_step`` classifies the action text, starts from the
    base tier in ``STEP_TYPE_TIER`` and then applies, in order, the
    prior-failure, error-count, risk-floor, budget and late-step adjustments.
    The cost of the chosen tier is charged against ``budget_remaining`` and a
    record of the decision is appended to ``step_history``.
    """

    # Model identifier returned for each tier number.
    MODEL_IDS: Dict[int, str] = {
        1: "tiny-local-3b",
        2: "cheap-cloud-8b",
        3: "medium-70b",
        4: "frontier-latest",
        5: "specialist-expert",
    }

    def __init__(self, max_budget: float = 5.0,
                 tier_costs: Optional[Dict[int, float]] = None):
        """Create a router with a fresh budget.

        Args:
            max_budget: Total amount that may be spent before the router
                starts downgrading tiers.
            tier_costs: Cost per step for each tier number. When omitted (or
                empty) the built-in cost table for tiers 1-5 is used.
        """
        self.max_budget = max_budget
        self.tier_costs = tier_costs or {1: 0.01, 2: 0.05, 3: 0.15, 4: 0.30, 5: 0.50}
        self.step_history: List[Dict] = []
        self.budget_remaining: float = max_budget
        self.total_spent: float = 0.0

    @staticmethod
    def _has_keyword(text: str, keyword: str) -> bool:
        """Return True if *keyword* occurs in *text* at the start of a word.

        An occurrence counts only when it is at the beginning of *text* or is
        preceded by a non-letter, so "ask" is not found inside "task" while
        "run" is still found in "running" and "read" in "file_read".
        """
        pos = text.find(keyword)
        while pos != -1:
            if pos == 0 or not text[pos - 1].isalpha():
                return True
            pos = text.find(keyword, pos + 1)
        return False

    def classify_step(self, action: str) -> StepType:
        """Map an action description to a ``StepType``.

        The lower-cased text is scanned against ``STEP_PATTERNS`` in the
        dict's insertion order and the first step type with a keyword hit is
        returned. Keywords must begin a word; plain substring matching would
        misfire on embedded fragments (e.g. "view" inside "review", "cat"
        inside "truncate").

        Args:
            action: Free-text description of the step.

        Returns:
            The matching ``StepType``, or ``StepType.UNKNOWN`` when no
            keyword matches.
        """
        lowered = action.lower()
        for stype, patterns in STEP_PATTERNS.items():
            if any(self._has_keyword(lowered, p) for p in patterns):
                return stype
        return StepType.UNKNOWN

    def route_step(self, action: str, step_num: int,
                   has_prior_failures: bool = False,
                   num_errors_seen: int = 0,
                   task_risk: str = "medium") -> StepRoutingDecision:
        """Pick the tier and model for one step and charge its cost.

        Args:
            action: Free-text description of the step; used for classification.
            step_num: Index of the step within the task. Steps numbered above
                5 get a one-tier discount when they are READ or SEARCH steps.
            has_prior_failures: Escalate by one tier (capped at 5) when True.
            num_errors_seen: Escalate by one more tier (capped at 5) when at
                least 3 errors have been seen.
            task_risk: "critical" floors EDIT and VERIFY steps at tier 4;
                "high" floors EDIT steps at tier 3. Other values apply no floor.

        Returns:
            A ``StepRoutingDecision`` describing the chosen tier, the reasons
            for any adjustment, the cost charged and the model identifier.

        Side effects:
            Subtracts the cost from ``budget_remaining``, adds it to
            ``total_spent`` and appends a record to ``step_history``. The
            cost is charged unconditionally, so ``budget_remaining`` can
            become negative when no cheaper tier is affordable.
        """
        step_type = self.classify_step(action)
        base_tier = STEP_TYPE_TIER.get(step_type, 3)
        adjusted_tier = base_tier
        reasons = []

        # Adjustment 1: Prior failures → escalate
        if has_prior_failures:
            adjusted_tier = min(adjusted_tier + 1, 5)
            reasons.append("prior_failures→+1")

        # Adjustment 2: Many errors → need stronger model
        if num_errors_seen >= 3:
            adjusted_tier = min(adjusted_tier + 1, 5)
            reasons.append("many_errors→+1")

        # Adjustment 3: High-risk task → floor
        if task_risk == "critical" and step_type in (StepType.EDIT, StepType.VERIFY):
            adjusted_tier = max(adjusted_tier, 4)
            reasons.append("critical_risk→floor4")
        elif task_risk == "high" and step_type == StepType.EDIT:
            adjusted_tier = max(adjusted_tier, 3)
            reasons.append("high_risk→floor3")

        # Adjustment 4: Budget constraint → downgrade if needed.
        # A tier counts as affordable when twice its cost fits in the budget.
        step_cost = self.tier_costs.get(adjusted_tier, 0.15)
        if self.budget_remaining < step_cost * 2:
            # Never drop more than one tier below the step type's base tier.
            # NOTE(review): this can undercut the risk floor set above.
            min_tier = STEP_TYPE_TIER.get(step_type, 2) - 1
            # Downgrade to cheapest viable
            for t in range(1, adjusted_tier):
                if t >= min_tier and self.tier_costs.get(t, 1.0) * 2 <= self.budget_remaining:
                    adjusted_tier = t
                    reasons.append("budget_constraint→downgrade")
                    break

        # Adjustment 5: Late step → can use cheaper (context is built up)
        if step_num > 5 and step_type in (StepType.READ, StepType.SEARCH):
            # Already have good context, cheaper model suffices
            adjusted_tier = max(adjusted_tier - 1, 1)
            reasons.append("late_step_context→-1")

        cost = self.tier_costs.get(adjusted_tier, 0.15)
        self.budget_remaining -= cost
        self.total_spent += cost

        reason = "; ".join(reasons) if reasons else "default"

        # Keep a per-step trail so callers can inspect how the budget was spent.
        self.step_history.append({
            "step_num": step_num,
            "action": action,
            "step_type": step_type.value,
            "base_tier": base_tier,
            "adjusted_tier": adjusted_tier,
            "reason": reason,
            "cost": cost,
        })

        return StepRoutingDecision(
            step_type=step_type, base_tier=base_tier,
            adjusted_tier=adjusted_tier,
            reason=reason,
            cost_estimate=cost,
            model_id=self.MODEL_IDS.get(adjusted_tier, "medium-70b"),
        )

    def reset(self):
        """Clear the step history and restore the full budget."""
        self.step_history = []
        self.budget_remaining = self.max_budget
        self.total_spent = 0.0