Upload aco/per_step_router.py with huggingface_hub
Browse files- aco/per_step_router.py +141 -0
aco/per_step_router.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Per-Step Router: Route each agent step, not just the task.
|
| 2 |
+
|
| 3 |
+
Key insight from BAAR paper: Tasks get harder mid-execution.
|
| 4 |
+
A task that starts easy (search codebase) may end hard (patch critical bug).
|
| 5 |
+
The router should re-evaluate at each step based on:
|
| 6 |
+
- Current step type (search, read, edit, execute, verify)
|
| 7 |
+
- Results so far (errors found, tools failed)
|
| 8 |
+
- Remaining budget
|
| 9 |
+
- Confidence in current trajectory
|
| 10 |
+
|
| 11 |
+
Step types and their typical tier requirements:
|
| 12 |
+
- search/list: tier 1-2 (easy, pattern matching)
|
| 13 |
+
- read/understand: tier 2-3 (moderate comprehension)
|
| 14 |
+
- edit/patch: tier 3-4 (requires deep understanding)
|
| 15 |
+
- execute/test: tier 2-3 (pattern matching on errors)
|
| 16 |
+
- verify/review: tier 3-5 (requires expert judgment)
|
| 17 |
+
- plan/orchestrate: tier 4-5 (strategic thinking)
|
| 18 |
+
"""
|
| 19 |
+
from typing import Dict, List, Optional, Tuple
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from enum import Enum
|
| 22 |
+
|
| 23 |
+
class StepType(Enum):
|
| 24 |
+
SEARCH = "search"
|
| 25 |
+
READ = "read"
|
| 26 |
+
EDIT = "edit"
|
| 27 |
+
EXECUTE = "execute"
|
| 28 |
+
VERIFY = "verify"
|
| 29 |
+
PLAN = "plan"
|
| 30 |
+
COMMUNICATE = "communicate"
|
| 31 |
+
UNKNOWN = "unknown"
|
| 32 |
+
|
| 33 |
+
# Base tier for each step type (before context adjustment)
|
| 34 |
+
STEP_TYPE_TIER = {
|
| 35 |
+
StepType.SEARCH: 2,
|
| 36 |
+
StepType.READ: 2,
|
| 37 |
+
StepType.EDIT: 3,
|
| 38 |
+
StepType.EXECUTE: 2,
|
| 39 |
+
StepType.VERIFY: 3,
|
| 40 |
+
StepType.PLAN: 4,
|
| 41 |
+
StepType.COMMUNICATE: 2,
|
| 42 |
+
StepType.UNKNOWN: 3,
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
# Step type detection from action description
|
| 46 |
+
STEP_PATTERNS = {
|
| 47 |
+
StepType.SEARCH: ["search","find","grep","list","locate","look up","query"],
|
| 48 |
+
StepType.READ: ["read","cat","show","display","view","inspect","open","check"],
|
| 49 |
+
StepType.EDIT: ["edit","write","patch","modify","change","fix","update","insert","delete","replace"],
|
| 50 |
+
StepType.EXECUTE: ["run","execute","test","compile","build","bash","shell","python"],
|
| 51 |
+
StepType.VERIFY: ["verify","validate","check","review","confirm","assert","lint"],
|
| 52 |
+
StepType.PLAN: ["plan","decide","choose","strategy","orchestrate","coordinate"],
|
| 53 |
+
StepType.COMMUNICATE: ["ask","tell","inform","report","summarize","explain"],
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
@dataclass
|
| 57 |
+
class StepRoutingDecision:
|
| 58 |
+
step_type: StepType
|
| 59 |
+
base_tier: int
|
| 60 |
+
adjusted_tier: int
|
| 61 |
+
reason: str
|
| 62 |
+
cost_estimate: float
|
| 63 |
+
model_id: str
|
| 64 |
+
|
| 65 |
+
class PerStepRouter:
|
| 66 |
+
def __init__(self, max_budget: float = 5.0, tier_costs: Dict[int,float] = None):
|
| 67 |
+
self.max_budget = max_budget
|
| 68 |
+
self.tier_costs = tier_costs or {1:0.01,2:0.05,3:0.15,4:0.30,5:0.50}
|
| 69 |
+
self.step_history: List[Dict] = []
|
| 70 |
+
self.budget_remaining: float = max_budget
|
| 71 |
+
self.total_spent: float = 0.0
|
| 72 |
+
|
| 73 |
+
def classify_step(self, action: str) -> StepType:
|
| 74 |
+
r = action.lower()
|
| 75 |
+
for stype, patterns in STEP_PATTERNS.items():
|
| 76 |
+
if any(p in r for p in patterns):
|
| 77 |
+
return stype
|
| 78 |
+
return StepType.UNKNOWN
|
| 79 |
+
|
| 80 |
+
def route_step(self, action: str, step_num: int,
|
| 81 |
+
has_prior_failures: bool = False,
|
| 82 |
+
num_errors_seen: int = 0,
|
| 83 |
+
task_risk: str = "medium") -> StepRoutingDecision:
|
| 84 |
+
step_type = self.classify_step(action)
|
| 85 |
+
base_tier = STEP_TYPE_TIER.get(step_type, 3)
|
| 86 |
+
adjusted_tier = base_tier
|
| 87 |
+
reasons = []
|
| 88 |
+
|
| 89 |
+
# Adjustment 1: Prior failures → escalate
|
| 90 |
+
if has_prior_failures:
|
| 91 |
+
adjusted_tier = min(adjusted_tier + 1, 5)
|
| 92 |
+
reasons.append("prior_failures→+1")
|
| 93 |
+
|
| 94 |
+
# Adjustment 2: Many errors → need stronger model
|
| 95 |
+
if num_errors_seen >= 3:
|
| 96 |
+
adjusted_tier = min(adjusted_tier + 1, 5)
|
| 97 |
+
reasons.append("many_errors→+1")
|
| 98 |
+
|
| 99 |
+
# Adjustment 3: High-risk task → floor
|
| 100 |
+
if task_risk == "critical" and step_type in (StepType.EDIT, StepType.VERIFY):
|
| 101 |
+
adjusted_tier = max(adjusted_tier, 4)
|
| 102 |
+
reasons.append("critical_risk→floor4")
|
| 103 |
+
elif task_risk == "high" and step_type == StepType.EDIT:
|
| 104 |
+
adjusted_tier = max(adjusted_tier, 3)
|
| 105 |
+
reasons.append("high_risk→floor3")
|
| 106 |
+
|
| 107 |
+
# Adjustment 4: Budget constraint → downgrade if needed
|
| 108 |
+
step_cost = self.tier_costs.get(adjusted_tier, 0.15)
|
| 109 |
+
if self.budget_remaining < step_cost * 2:
|
| 110 |
+
# Downgrade to cheapest viable
|
| 111 |
+
for t in range(1, adjusted_tier):
|
| 112 |
+
if self.tier_costs.get(t, 1.0) * 2 <= self.budget_remaining:
|
| 113 |
+
if t >= STEP_TYPE_TIER.get(step_type, 2) - 1:
|
| 114 |
+
adjusted_tier = t
|
| 115 |
+
reasons.append("budget_constraint→downgrade")
|
| 116 |
+
break
|
| 117 |
+
|
| 118 |
+
# Adjustment 5: Late step → can use cheaper (context is built up)
|
| 119 |
+
if step_num > 5 and step_type in (StepType.READ, StepType.SEARCH):
|
| 120 |
+
# Already have good context, cheaper model suffices
|
| 121 |
+
adjusted_tier = max(adjusted_tier - 1, 1)
|
| 122 |
+
reasons.append("late_step_context→-1")
|
| 123 |
+
|
| 124 |
+
cost = self.tier_costs.get(adjusted_tier, 0.15)
|
| 125 |
+
self.budget_remaining -= cost
|
| 126 |
+
self.total_spent += cost
|
| 127 |
+
|
| 128 |
+
model_ids = {1:"tiny-local-3b",2:"cheap-cloud-8b",3:"medium-70b",
|
| 129 |
+
4:"frontier-latest",5:"specialist-expert"}
|
| 130 |
+
|
| 131 |
+
return StepRoutingDecision(
|
| 132 |
+
step_type=step_type, base_tier=base_tier,
|
| 133 |
+
adjusted_tier=adjusted_tier,
|
| 134 |
+
reason="; ".join(reasons) if reasons else "default",
|
| 135 |
+
cost_estimate=cost, model_id=model_ids.get(adjusted_tier, "medium-70b"),
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def reset(self):
|
| 139 |
+
self.step_history = []
|
| 140 |
+
self.budget_remaining = self.max_budget
|
| 141 |
+
self.total_spent = 0.0
|