narcolepticchicken committed on
Commit
9146136
·
verified ·
1 Parent(s): 8191bbb

Upload aco/per_step_router.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/per_step_router.py +141 -0
aco/per_step_router.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Per-Step Router: Route each agent step, not just the task.

Key insight from the BAAR paper: tasks get harder mid-execution.
A task that starts easy (search codebase) may end hard (patch critical bug).
The router should re-evaluate at each step based on:
  - Current step type (search, read, edit, execute, verify)
  - Results so far (errors found, tools failed)
  - Remaining budget
  - Confidence in current trajectory

Step types and their typical tier requirements:
  - search/list: tier 1-2 (easy, pattern matching)
  - read/understand: tier 2-3 (moderate comprehension)
  - edit/patch: tier 3-4 (requires deep understanding)
  - execute/test: tier 2-3 (pattern matching on errors)
  - verify/review: tier 3-5 (requires expert judgment)
  - plan/orchestrate: tier 4-5 (strategic thinking)
"""
19
+ from typing import Dict, List, Optional, Tuple
20
+ from dataclasses import dataclass
21
+ from enum import Enum
22
+
23
class StepType(Enum):
    """Kinds of agent step the router distinguishes.

    Each member's value is the lowercase label of that step kind.
    """

    SEARCH = "search"
    READ = "read"
    EDIT = "edit"
    EXECUTE = "execute"
    VERIFY = "verify"
    PLAN = "plan"
    COMMUNICATE = "communicate"
    # Returned by PerStepRouter.classify_step when no keyword matches.
    UNKNOWN = "unknown"
32
+
33
# Base tier for each step type (before context adjustment).
STEP_TYPE_TIER = {
    StepType.SEARCH: 2,
    StepType.READ: 2,
    StepType.EDIT: 3,
    StepType.EXECUTE: 2,
    StepType.VERIFY: 3,
    StepType.PLAN: 4,
    StepType.COMMUNICATE: 2,
    StepType.UNKNOWN: 3,
}

# Step type detection from action description.
# PerStepRouter.classify_step walks this mapping in insertion order and
# returns the first step type with a keyword contained in the lowercased
# action, so ordering matters: "check" is listed under both READ and VERIFY
# and therefore always resolves to READ.
STEP_PATTERNS = {
    StepType.SEARCH: ["search", "find", "grep", "list", "locate", "look up", "query"],
    StepType.READ: ["read", "cat", "show", "display", "view", "inspect", "open", "check"],
    StepType.EDIT: [
        "edit", "write", "patch", "modify", "change", "fix", "update",
        "insert", "delete", "replace",
    ],
    StepType.EXECUTE: [
        "run", "execute", "test", "compile", "build", "bash", "shell", "python",
    ],
    StepType.VERIFY: ["verify", "validate", "check", "review", "confirm", "assert", "lint"],
    StepType.PLAN: ["plan", "decide", "choose", "strategy", "orchestrate", "coordinate"],
    StepType.COMMUNICATE: ["ask", "tell", "inform", "report", "summarize", "explain"],
}
55
+
56
@dataclass
class StepRoutingDecision:
    """Outcome of routing a single agent step, as built by PerStepRouter.route_step."""

    # Step kind the action description was classified as.
    step_type: StepType
    # Tier looked up from STEP_TYPE_TIER before any adjustment.
    base_tier: int
    # Tier after the failure / risk / budget / late-step adjustments.
    adjusted_tier: int
    # "; "-joined tags of the adjustments applied, or "default" when none were.
    reason: str
    # Cost charged against the router budget for the adjusted tier.
    cost_estimate: float
    # Identifier of the model mapped to the adjusted tier.
    model_id: str
64
+
65
class PerStepRouter:
    """Choose a model tier for each agent step while tracking a spending budget.

    Every call to :meth:`route_step` classifies the action, adjusts the base
    tier of its step type, charges the tier cost against the budget and
    records the step in ``step_history``.

    Attributes:
        max_budget: Budget restored by :meth:`reset`.
        tier_costs: Mapping of tier number to per-step cost.
        step_history: One dict per routed step, in call order.
        budget_remaining: Budget left after the steps routed so far. The cost
            is charged even when no affordable tier was found, so this value
            can become negative.
        total_spent: Sum of the costs charged so far.
    """

    # Model served at each tier; tiers outside this table use the default id.
    _MODEL_IDS = {
        1: "tiny-local-3b",
        2: "cheap-cloud-8b",
        3: "medium-70b",
        4: "frontier-latest",
        5: "specialist-expert",
    }
    _DEFAULT_MODEL_ID = "medium-70b"
    # Cost assumed for a tier that is missing from ``tier_costs``.
    _DEFAULT_TIER_COST = 0.15
    _MAX_TIER = 5

    def __init__(self, max_budget: float = 5.0,
                 tier_costs: Optional[Dict[int, float]] = None):
        """Create a router.

        Args:
            max_budget: Total budget available across all routed steps.
            tier_costs: Per-step cost of each tier. ``None`` (or an empty
                mapping) selects the built-in cost table.
        """
        self.max_budget = max_budget
        self.tier_costs = tier_costs or {
            1: 0.01,
            2: 0.05,
            3: 0.15,
            4: 0.30,
            5: 0.50,
        }
        self.step_history: List[Dict] = []
        self.budget_remaining: float = max_budget
        self.total_spent: float = 0.0

    def classify_step(self, action: str) -> StepType:
        """Map a free-text action description to a step type.

        The lowercased action is scanned for each keyword list in
        ``STEP_PATTERNS`` in insertion order; the first step type with a
        keyword contained in the text wins.

        Args:
            action: Description of the action about to be performed.

        Returns:
            The matched step type, or ``StepType.UNKNOWN`` when no keyword
            occurs in the action.
        """
        text = action.lower()
        # NOTE(review): matching is by plain substring, so a keyword embedded
        # in a longer word also hits (e.g. "ask" inside "task"). Kept as is
        # because tightening it would reclassify existing actions.
        for stype, patterns in STEP_PATTERNS.items():
            if any(keyword in text for keyword in patterns):
                return stype
        return StepType.UNKNOWN

    def route_step(self, action: str, step_num: int,
                   has_prior_failures: bool = False,
                   num_errors_seen: int = 0,
                   task_risk: str = "medium") -> StepRoutingDecision:
        """Route one step: pick its tier, charge the budget, record the step.

        Args:
            action: Description of the action; classified via
                :meth:`classify_step`.
            step_num: Position of the step within the task. Values above 5
                make READ and SEARCH steps one tier cheaper.
            has_prior_failures: Escalate one tier when true.
            num_errors_seen: Escalate one tier when at least 3.
            task_risk: ``"critical"`` floors EDIT/VERIFY steps at tier 4 and
                ``"high"`` floors EDIT steps at tier 3; any other value has
                no effect.

        Returns:
            The routing decision for this step. Its cost has already been
            subtracted from ``budget_remaining`` and added to ``total_spent``.
        """
        step_type = self.classify_step(action)
        base_tier = STEP_TYPE_TIER.get(step_type, 3)
        adjusted_tier = base_tier
        reasons = []

        # Adjustment 1: prior failures -> escalate.
        if has_prior_failures:
            adjusted_tier = min(adjusted_tier + 1, self._MAX_TIER)
            reasons.append("prior_failures→+1")

        # Adjustment 2: many errors -> need a stronger model.
        if num_errors_seen >= 3:
            adjusted_tier = min(adjusted_tier + 1, self._MAX_TIER)
            reasons.append("many_errors→+1")

        # Adjustment 3: risky task -> enforce a minimum tier.
        if task_risk == "critical" and step_type in (StepType.EDIT, StepType.VERIFY):
            adjusted_tier = max(adjusted_tier, 4)
            reasons.append("critical_risk→floor4")
        elif task_risk == "high" and step_type == StepType.EDIT:
            adjusted_tier = max(adjusted_tier, 3)
            reasons.append("high_risk→floor3")

        # Adjustment 4: budget constraint -> downgrade when fewer than two
        # steps at the chosen tier are still affordable.
        step_cost = self.tier_costs.get(adjusted_tier, self._DEFAULT_TIER_COST)
        if self.budget_remaining < step_cost * 2:
            # Take the cheapest tier that is affordable twice over and is at
            # most one tier below the step type's base tier.
            for tier in range(1, adjusted_tier):
                if self.tier_costs.get(tier, 1.0) * 2 <= self.budget_remaining:
                    if tier >= STEP_TYPE_TIER.get(step_type, 2) - 1:
                        adjusted_tier = tier
                        reasons.append("budget_constraint→downgrade")
                        break

        # Adjustment 5: late READ/SEARCH step -> context is already built up,
        # so a cheaper model suffices.
        if step_num > 5 and step_type in (StepType.READ, StepType.SEARCH):
            adjusted_tier = max(adjusted_tier - 1, 1)
            reasons.append("late_step_context→-1")

        cost = self.tier_costs.get(adjusted_tier, self._DEFAULT_TIER_COST)
        self.budget_remaining -= cost
        self.total_spent += cost

        decision = StepRoutingDecision(
            step_type=step_type,
            base_tier=base_tier,
            adjusted_tier=adjusted_tier,
            reason="; ".join(reasons) if reasons else "default",
            cost_estimate=cost,
            model_id=self._MODEL_IDS.get(adjusted_tier, self._DEFAULT_MODEL_ID),
        )

        # Record the step so the history reflects what was actually routed.
        self.step_history.append({
            "step_num": step_num,
            "action": action,
            "step_type": step_type.value,
            "base_tier": base_tier,
            "adjusted_tier": adjusted_tier,
            "reason": decision.reason,
            "cost": cost,
            "model_id": decision.model_id,
            "budget_remaining": self.budget_remaining,
        })
        return decision

    def reset(self):
        """Clear the step history and restore the full budget."""
        self.step_history = []
        self.budget_remaining = self.max_budget
        self.total_spent = 0.0