"""Per-Step Router: Route each agent step, not just the task.

Key insight from BAAR paper: Tasks get harder mid-execution.
A task that starts easy (search codebase) may end hard (patch critical bug).
The router should re-evaluate at each step based on:
- Current step type (search, read, edit, execute, verify)
- Results so far (errors found, tools failed)
- Remaining budget
- Confidence in current trajectory

Step types and their typical tier requirements:
- search/list: tier 1-2 (easy, pattern matching)
- read/understand: tier 2-3 (moderate comprehension)
- edit/patch: tier 3-4 (requires deep understanding)
- execute/test: tier 2-3 (pattern matching on errors)
- verify/review: tier 3-5 (requires expert judgment)
- plan/orchestrate: tier 4-5 (strategic thinking)
"""
| from typing import Dict, List, Optional, Tuple |
| from dataclasses import dataclass |
| from enum import Enum |
|
|
class StepType(Enum):
    """Categories of agent step that the router tells apart.

    Each member's value is the lowercase name of the category.
    """

    # Locating things: searching, listing, querying.
    SEARCH = "search"
    # Looking at existing content.
    READ = "read"
    # Changing content.
    EDIT = "edit"
    # Running commands, builds, or tests.
    EXECUTE = "execute"
    # Validating or reviewing results.
    VERIFY = "verify"
    # Deciding what to do next.
    PLAN = "plan"
    # Exchanging information (asking, reporting, explaining).
    COMMUNICATE = "communicate"
    # Fallback when no category applies.
    UNKNOWN = "unknown"
|
|
| |
# Baseline model tier for each step type, before any per-step adjustment
# (failures, risk, budget, late-step discount) is applied by the router.
STEP_TYPE_TIER: Dict[StepType, int] = {
    StepType.SEARCH: 2,
    StepType.READ: 2,
    StepType.EDIT: 3,
    StepType.EXECUTE: 2,
    StepType.VERIFY: 3,
    StepType.PLAN: 4,
    StepType.COMMUNICATE: 2,
    StepType.UNKNOWN: 3,
}
|
|
| |
# Lowercase keywords that mark an action string as belonging to a step type.
# The dict is scanned in insertion order and the first type with a matching
# keyword wins, so ordering here is significant. Note that "check" is listed
# under both READ and VERIFY; READ comes first in the scan.
STEP_PATTERNS: Dict[StepType, List[str]] = {
    StepType.SEARCH: [
        "search", "find", "grep", "list", "locate", "look up", "query",
    ],
    StepType.READ: [
        "read", "cat", "show", "display", "view", "inspect", "open", "check",
    ],
    StepType.EDIT: [
        "edit", "write", "patch", "modify", "change", "fix", "update",
        "insert", "delete", "replace",
    ],
    StepType.EXECUTE: [
        "run", "execute", "test", "compile", "build", "bash", "shell", "python",
    ],
    StepType.VERIFY: [
        "verify", "validate", "check", "review", "confirm", "assert", "lint",
    ],
    StepType.PLAN: [
        "plan", "decide", "choose", "strategy", "orchestrate", "coordinate",
    ],
    StepType.COMMUNICATE: [
        "ask", "tell", "inform", "report", "summarize", "explain",
    ],
}
|
|
@dataclass
class StepRoutingDecision:
    """Outcome of routing a single agent step."""

    # Category the action was classified as.
    step_type: StepType
    # Tier looked up for that category before adjustments.
    base_tier: int
    # Tier actually selected after all adjustments.
    adjusted_tier: int
    # "; "-joined adjustment tags, or "default" when none applied.
    reason: str
    # Cost charged against the budget for the selected tier.
    cost_estimate: float
    # Identifier of the model mapped to the selected tier.
    model_id: str
|
|
class PerStepRouter:
    """Choose a model tier for every agent step while tracking a budget.

    Each call to :meth:`route_step` classifies the action text, looks up the
    base tier for that step type, applies the adjustment rules, charges the
    resulting tier's cost against the remaining budget, and records the step
    in ``step_history``.
    """

    # Cost table used when the caller supplies no (or an empty) tier_costs.
    _DEFAULT_TIER_COSTS: Dict[int, float] = {
        1: 0.01, 2: 0.05, 3: 0.15, 4: 0.30, 5: 0.50,
    }

    # Tier -> model identifier reported in each routing decision.
    _MODEL_IDS: Dict[int, str] = {
        1: "tiny-local-3b",
        2: "cheap-cloud-8b",
        3: "medium-70b",
        4: "frontier-latest",
        5: "specialist-expert",
    }

    def __init__(self, max_budget: float = 5.0,
                 tier_costs: Optional[Dict[int, float]] = None):
        """Create a router.

        Args:
            max_budget: Starting budget; also the value ``reset`` restores.
            tier_costs: Mapping of tier number to per-step cost. ``None`` or
                an empty mapping selects the built-in cost table.
        """
        self.max_budget = max_budget
        # Copy the defaults so per-instance edits never leak between routers.
        self.tier_costs = tier_costs or dict(self._DEFAULT_TIER_COSTS)
        self.step_history: List[Dict] = []
        self.budget_remaining: float = max_budget
        self.total_spent: float = 0.0

    @staticmethod
    def _starts_word(text: str, pattern: str) -> bool:
        """Return True if *pattern* occurs in *text* at the start of a word.

        A word start is the beginning of the text or any position whose
        preceding character is not alphanumeric (so ``_``, ``-``, spaces and
        punctuation all count as separators).
        """
        pos = text.find(pattern)
        while pos != -1:
            if pos == 0 or not text[pos - 1].isalnum():
                return True
            pos = text.find(pattern, pos + 1)
        return False

    def classify_step(self, action: str) -> "StepType":
        """Map an action string to a step type via keyword matching.

        Step types are tried in the order they appear in ``STEP_PATTERNS``.
        Keywords that begin a word are preferred, so "review" is no longer
        captured by the READ keyword "view" and "explain" is no longer
        captured by the PLAN keyword "plan". When no keyword begins a word,
        plain substring matching is used as a fallback so that every action
        that previously matched some type still does.

        Returns:
            The first matching step type, or ``StepType.UNKNOWN``.
        """
        text = action.lower()

        # Pass 1: keyword must start a word.
        for stype, patterns in STEP_PATTERNS.items():
            if any(self._starts_word(text, p) for p in patterns):
                return stype

        # Pass 2: substring anywhere (e.g. "pytest", "rerun").
        for stype, patterns in STEP_PATTERNS.items():
            if any(p in text for p in patterns):
                return stype

        return StepType.UNKNOWN

    def route_step(self, action: str, step_num: int,
                   has_prior_failures: bool = False,
                   num_errors_seen: int = 0,
                   task_risk: str = "medium") -> "StepRoutingDecision":
        """Select a tier for one step, charge its cost, and record it.

        Rules are applied in this order:
          1. ``has_prior_failures`` raises the tier by one (capped at 5).
          2. ``num_errors_seen >= 3`` raises the tier by one (capped at 5).
          3. ``task_risk == "critical"`` floors EDIT/VERIFY steps at tier 4;
             ``task_risk == "high"`` floors EDIT steps at tier 3.
          4. If the remaining budget is less than twice the tier's cost, the
             tier is lowered to the cheapest tier that costs at most half the
             remaining budget and is no more than one below the base tier.
             If no such tier exists the tier is left unchanged.
          5. READ/SEARCH steps with ``step_num > 5`` drop one tier (min 1).

        The selected tier's cost is always deducted, so ``budget_remaining``
        may become negative.

        Args:
            action: Free-text description of the step.
            step_num: Index of the step; only used by rule 5.
            has_prior_failures: Whether earlier steps failed.
            num_errors_seen: Number of errors observed so far.
            task_risk: Risk label; "critical" and "high" trigger rule 3.

        Returns:
            A ``StepRoutingDecision`` describing the choice.
        """
        step_type = self.classify_step(action)
        base_tier = STEP_TYPE_TIER.get(step_type, 3)
        adjusted_tier = base_tier
        reasons: List[str] = []

        if has_prior_failures:
            adjusted_tier = min(adjusted_tier + 1, 5)
            reasons.append("prior_failures→+1")

        if num_errors_seen >= 3:
            adjusted_tier = min(adjusted_tier + 1, 5)
            reasons.append("many_errors→+1")

        if task_risk == "critical" and step_type in (StepType.EDIT, StepType.VERIFY):
            adjusted_tier = max(adjusted_tier, 4)
            reasons.append("critical_risk→floor4")
        elif task_risk == "high" and step_type == StepType.EDIT:
            adjusted_tier = max(adjusted_tier, 3)
            reasons.append("high_risk→floor3")

        step_cost = self.tier_costs.get(adjusted_tier, 0.15)
        if self.budget_remaining < step_cost * 2:
            lowest_allowed = base_tier - 1
            # Scans cheapest-first, so the first qualifying tier is also the
            # cheapest one. NOTE(review): confirm cheapest (rather than the
            # highest affordable tier) is the intended downgrade target.
            for tier in range(1, adjusted_tier):
                affordable = self.tier_costs.get(tier, 1.0) * 2 <= self.budget_remaining
                if affordable and tier >= lowest_allowed:
                    adjusted_tier = tier
                    reasons.append("budget_constraint→downgrade")
                    break

        if step_num > 5 and step_type in (StepType.READ, StepType.SEARCH):
            adjusted_tier = max(adjusted_tier - 1, 1)
            reasons.append("late_step_context→-1")

        cost = self.tier_costs.get(adjusted_tier, 0.15)
        self.budget_remaining -= cost
        self.total_spent += cost

        reason = "; ".join(reasons) if reasons else "default"

        # Keep a per-step trail; previously the history list was never filled.
        self.step_history.append({
            "step_num": step_num,
            "action": action,
            "step_type": step_type,
            "base_tier": base_tier,
            "adjusted_tier": adjusted_tier,
            "reason": reason,
            "cost": cost,
        })

        return StepRoutingDecision(
            step_type=step_type,
            base_tier=base_tier,
            adjusted_tier=adjusted_tier,
            reason=reason,
            cost_estimate=cost,
            model_id=self._MODEL_IDS.get(adjusted_tier, "medium-70b"),
        )

    def reset(self):
        """Clear the step history and restore the full starting budget."""
        self.step_history = []
        self.budget_remaining = self.max_budget
        self.total_spent = 0.0
|
|