Rohan03 commited on
Commit
80a4e8f
·
verified ·
1 Parent(s): f6a5e41

Sprint 5: routing.py — LLMCallRouter, ModelSelector, cost homeostasis

Browse files
Files changed (1) hide show
  1. purpose_agent/routing.py +194 -0
purpose_agent/routing.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ routing.py — SLM-native LLM call router with cost homeostasis.
3
+
4
+ Routes tasks to the smallest capable model. Local-first by default.
5
+ Enforces cost, latency, and token budgets as hard constraints.
6
+
7
+ Complexity classification:
8
+ simple → single SLM call (summarize, answer simple Q)
9
+ moderate → sequential chain (plan → execute)
10
+ complex → parallel specialists (research + code + review)
11
+ critical → specialists + critic ensemble + optional HITL
12
+
13
+ Router decisions are logged and reproducible.
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ import time
19
+ from dataclasses import dataclass, field
20
+ from enum import Enum
21
+ from typing import Any
22
+
23
+ from purpose_agent.llm_backend import LLMBackend
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class TaskComplexity(str, Enum):
29
+ SIMPLE = "simple"
30
+ MODERATE = "moderate"
31
+ COMPLEX = "complex"
32
+ CRITICAL = "critical"
33
+
34
+
35
+ @dataclass
36
+ class RoutingPolicy:
37
+ """Policy governing model selection and cost control."""
38
+ prefer_local: bool = True
39
+ max_cost_per_task_usd: float = 0.10
40
+ max_latency_per_call_s: float = 30.0
41
+ max_tokens_per_task: int = 10000
42
+ allow_cloud_fallback: bool = True
43
+ fallback_model: str = ""
44
+ local_model: str = "ollama:qwen3:1.7b"
45
+ cloud_model: str = "openrouter:meta-llama/llama-3.3-70b-instruct"
46
+
47
+
48
+ @dataclass
49
+ class ModelOption:
50
+ """A model available for routing."""
51
+ spec: str # e.g. "ollama:qwen3:1.7b"
52
+ is_local: bool = True
53
+ cost_per_1k_tokens: float = 0.0 # $0 for local
54
+ avg_latency_s: float = 1.0
55
+ max_context: int = 32768
56
+ capabilities: list[str] = field(default_factory=list) # ["code","reasoning","general"]
57
+
58
+
59
+ @dataclass
60
+ class RoutingDecision:
61
+ """Recorded decision from the router."""
62
+ task_summary: str
63
+ complexity: TaskComplexity
64
+ selected_model: str
65
+ reason: str
66
+ timestamp: float = field(default_factory=time.time)
67
+ estimated_cost: float = 0.0
68
+
69
+
70
+ # Keyword-based complexity heuristics
71
+ _COMPLEX_KEYWORDS = {"research", "analyze", "compare", "design", "architect", "security", "audit"}
72
+ _CRITICAL_KEYWORDS = {"deploy", "production", "delete", "admin", "payment", "credential", "secret"}
73
+ _SIMPLE_KEYWORDS = {"summarize", "translate", "hello", "what is", "define", "explain"}
74
+
75
+
76
+ class TaskComplexityClassifier:
77
+ """Classifies task complexity from the purpose description."""
78
+
79
+ def classify(self, purpose: str) -> TaskComplexity:
80
+ words = set(purpose.lower().split())
81
+
82
+ if words & _CRITICAL_KEYWORDS:
83
+ return TaskComplexity.CRITICAL
84
+ if words & _COMPLEX_KEYWORDS:
85
+ return TaskComplexity.COMPLEX
86
+ if words & _SIMPLE_KEYWORDS:
87
+ return TaskComplexity.SIMPLE
88
+ # Default: moderate for anything with multiple sentences or code-related
89
+ if len(purpose) > 100 or "code" in purpose.lower() or "function" in purpose.lower():
90
+ return TaskComplexity.MODERATE
91
+ return TaskComplexity.SIMPLE
92
+
93
+
94
+ class ModelSelector:
95
+ """
96
+ Selects the best model for a task given complexity and policy.
97
+
98
+ Rules:
99
+ 1. Local-first (if policy.prefer_local and local model available)
100
+ 2. Smallest capable model (don't use 70B for "say hello")
101
+ 3. Respect cost/latency budgets
102
+ 4. Fallback to cloud only when policy allows and local fails
103
+ """
104
+
105
+ def __init__(self, models: list[ModelOption] | None = None, policy: RoutingPolicy | None = None):
106
+ self.models = models or []
107
+ self.policy = policy or RoutingPolicy()
108
+
109
+ def select(self, complexity: TaskComplexity) -> str:
110
+ """Select the best model spec for given complexity."""
111
+ # Filter by policy
112
+ candidates = list(self.models)
113
+
114
+ if self.policy.prefer_local:
115
+ local = [m for m in candidates if m.is_local]
116
+ if local:
117
+ candidates = local
118
+
119
+ # For simple tasks, prefer smallest/cheapest
120
+ if complexity == TaskComplexity.SIMPLE:
121
+ candidates.sort(key=lambda m: m.cost_per_1k_tokens)
122
+ if candidates:
123
+ return candidates[0].spec
124
+
125
+ # For complex/critical, prefer most capable
126
+ if complexity in (TaskComplexity.COMPLEX, TaskComplexity.CRITICAL):
127
+ # Prefer cloud models with more capability
128
+ if self.policy.allow_cloud_fallback:
129
+ return self.policy.cloud_model
130
+ capable = [m for m in candidates if "reasoning" in m.capabilities or "code" in m.capabilities]
131
+ if capable:
132
+ return capable[0].spec
133
+
134
+ # Default: local model
135
+ return self.policy.local_model
136
+
137
+
138
+ class LLMCallRouter:
139
+ """
140
+ Main router: classifies task → selects model → logs decision.
141
+
142
+ Usage:
143
+ router = LLMCallRouter(policy=RoutingPolicy(prefer_local=True))
144
+ model_spec = router.route("Write a fibonacci function")
145
+ # → "ollama:qwen3:1.7b" (local, code task, moderate complexity)
146
+
147
+ model_spec = router.route("Audit production deployment for security vulnerabilities")
148
+ # → cloud model (critical task, needs strong reasoning)
149
+ """
150
+
151
+ def __init__(self, policy: RoutingPolicy | None = None, models: list[ModelOption] | None = None):
152
+ self.policy = policy or RoutingPolicy()
153
+ self.classifier = TaskComplexityClassifier()
154
+ self.selector = ModelSelector(models or [], self.policy)
155
+ self._decisions: list[RoutingDecision] = []
156
+ self._total_cost = 0.0
157
+
158
+ def route(self, task: str) -> str:
159
+ """Route a task to the best model. Returns model spec string."""
160
+ complexity = self.classifier.classify(task)
161
+ selected = self.selector.select(complexity)
162
+
163
+ # Budget check
164
+ if self._total_cost >= self.policy.max_cost_per_task_usd:
165
+ # Over budget: force local
166
+ selected = self.policy.local_model
167
+ reason = "budget_exceeded: forced local"
168
+ else:
169
+ reason = f"complexity={complexity.value}"
170
+
171
+ decision = RoutingDecision(
172
+ task_summary=task[:80],
173
+ complexity=complexity,
174
+ selected_model=selected,
175
+ reason=reason,
176
+ )
177
+ self._decisions.append(decision)
178
+ logger.info(f"Router: {complexity.value} → {selected} ({reason})")
179
+ return selected
180
+
181
+ def record_cost(self, cost_usd: float) -> None:
182
+ """Record cost of a completed call for budget tracking."""
183
+ self._total_cost += cost_usd
184
+
185
+ @property
186
+ def total_cost(self) -> float:
187
+ return self._total_cost
188
+
189
+ @property
190
+ def decisions(self) -> list[RoutingDecision]:
191
+ return self._decisions
192
+
193
+ def reset_budget(self) -> None:
194
+ self._total_cost = 0.0