File size: 6,706 Bytes
80a4e8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
routing.py — SLM-native LLM call router with cost homeostasis.

Routes tasks to the smallest capable model. Local-first by default.
Enforces cost, latency, and token budgets as hard constraints.

Complexity classification:
  simple    → single SLM call (summarize, answer simple Q)
  moderate  → sequential chain (plan → execute)
  complex   → parallel specialists (research + code + review)
  critical  → specialists + critic ensemble + optional HITL

Router decisions are logged and reproducible.
"""
from __future__ import annotations

import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any

from purpose_agent.llm_backend import LLMBackend

logger = logging.getLogger(__name__)


class TaskComplexity(str, Enum):
    """Closed set of task difficulty tiers that drive routing strategy.

    Inherits from ``str`` so members compare/serialize as their string
    values (useful for logging and JSON round-trips).
    """
    SIMPLE = "simple"      # single SLM call
    MODERATE = "moderate"  # sequential chain (plan, then execute)
    COMPLEX = "complex"    # parallel specialists
    CRITICAL = "critical"  # specialists + critic ensemble + optional HITL


@dataclass
class RoutingPolicy:
    """Policy governing model selection and cost control."""
    # Prefer locally-hosted models when any are registered.
    prefer_local: bool = True
    # Cost ceiling in USD; the router forces the local model once reached.
    max_cost_per_task_usd: float = 0.10
    # Maximum latency allowed per call, in seconds.
    # NOTE(review): not enforced anywhere in this module — confirm a caller does.
    max_latency_per_call_s: float = 30.0
    # Maximum tokens allowed per task.
    # NOTE(review): not enforced anywhere in this module — confirm a caller does.
    max_tokens_per_task: int = 10000
    # Permit escalation to a cloud model for complex/critical tasks.
    allow_cloud_fallback: bool = True
    # Explicit fallback spec; empty string means "unset".
    fallback_model: str = ""
    # Default local model spec (backend:model[:tag] form).
    local_model: str = "ollama:qwen3:1.7b"
    # Cloud model used when escalation is allowed.
    cloud_model: str = "openrouter:meta-llama/llama-3.3-70b-instruct"


@dataclass
class ModelOption:
    """A model available for routing, with the metadata selection needs."""
    spec: str                       # backend-qualified spec, e.g. "ollama:qwen3:1.7b"
    is_local: bool = True           # runs locally (selected first under prefer_local)
    cost_per_1k_tokens: float = 0.0  # USD per 1k tokens; $0 for local models
    avg_latency_s: float = 1.0      # typical per-call latency in seconds
    max_context: int = 32768        # context window size in tokens
    capabilities: list[str] = field(default_factory=list)  # e.g. ["code","reasoning","general"]


@dataclass
class RoutingDecision:
    """Recorded decision from the router (kept for audit/reproducibility)."""
    task_summary: str               # truncated task text (first 80 chars)
    complexity: TaskComplexity      # classifier's verdict for the task
    selected_model: str             # model spec chosen by the router
    reason: str                     # human-readable reason for the choice
    timestamp: float = field(default_factory=time.time)  # wall-clock of decision
    # NOTE(review): never populated by the visible router code — confirm
    # whether callers fill this in after the call completes.
    estimated_cost: float = 0.0


# Keyword-based complexity heuristics, matched against the lowercased task
# text. Tiers are checked critical -> complex -> simple (first hit wins).
_COMPLEX_KEYWORDS = {"research", "analyze", "compare", "design", "architect", "security", "audit"}
# High-stakes / destructive operations that warrant the critic ensemble.
_CRITICAL_KEYWORDS = {"deploy", "production", "delete", "admin", "payment", "credential", "secret"}
# NOTE(review): "what is" is a two-word phrase — any matcher that splits the
# text into single tokens must special-case multi-word entries.
_SIMPLE_KEYWORDS = {"summarize", "translate", "hello", "what is", "define", "explain"}


class TaskComplexityClassifier:
    """Classifies task complexity from the purpose description.

    Keyword-driven: critical keywords dominate, then complex, then simple.
    With no keyword hit, long or code-related prompts default to MODERATE
    and everything else to SIMPLE.
    """

    def classify(self, purpose: str) -> TaskComplexity:
        """Return the :class:`TaskComplexity` tier for *purpose*.

        Single-word keywords must match a whole token; multi-word phrases
        (e.g. "what is") are checked by substring, since a split() word set
        can never contain them. (Previously the phrase "what is" was dead:
        it sat in ``_SIMPLE_KEYWORDS`` but could never match ``words``.)
        """
        text = purpose.lower()
        words = set(text.split())

        def _hit(keywords: set[str]) -> bool:
            # Whole-token match for single words; substring for phrases.
            return any(
                (kw in text) if " " in kw else (kw in words)
                for kw in keywords
            )

        if _hit(_CRITICAL_KEYWORDS):
            return TaskComplexity.CRITICAL
        if _hit(_COMPLEX_KEYWORDS):
            return TaskComplexity.COMPLEX
        if _hit(_SIMPLE_KEYWORDS):
            return TaskComplexity.SIMPLE
        # Default: moderate for anything long or code-related.
        if len(purpose) > 100 or "code" in text or "function" in text:
            return TaskComplexity.MODERATE
        return TaskComplexity.SIMPLE


class ModelSelector:
    """Chooses a model spec for a task given its complexity and the policy.

    Selection order:
      1. Restrict to local models when ``policy.prefer_local`` and any exist.
      2. SIMPLE tasks take the cheapest remaining option.
      3. COMPLEX/CRITICAL tasks escalate straight to the policy's cloud
         model when cloud fallback is allowed; otherwise the first option
         with a "reasoning" or "code" capability wins.
      4. Anything unresolved falls back to ``policy.local_model``.
    """

    def __init__(self, models: list[ModelOption] | None = None, policy: RoutingPolicy | None = None):
        self.models = models or []
        self.policy = policy or RoutingPolicy()

    def select(self, complexity: TaskComplexity) -> str:
        """Return the chosen model spec string for *complexity*."""
        pool = self.models[:]

        # Local-first: narrow the pool only if it leaves something to pick.
        if self.policy.prefer_local:
            local_pool = [opt for opt in pool if opt.is_local]
            if local_pool:
                pool = local_pool

        # Trivial work goes to the cheapest available option.
        if complexity == TaskComplexity.SIMPLE and pool:
            cheapest = min(pool, key=lambda opt: opt.cost_per_1k_tokens)
            return cheapest.spec

        # Hard work escalates: cloud if allowed, else first capable option.
        if complexity in (TaskComplexity.COMPLEX, TaskComplexity.CRITICAL):
            if self.policy.allow_cloud_fallback:
                return self.policy.cloud_model
            for opt in pool:
                if "reasoning" in opt.capabilities or "code" in opt.capabilities:
                    return opt.spec

        # Default: the policy's local model.
        return self.policy.local_model


class LLMCallRouter:
    """
    Main router: classifies task → selects model → logs decision.

    Usage:
        router = LLMCallRouter(policy=RoutingPolicy(prefer_local=True))
        model_spec = router.route("Write a fibonacci function")
        # → "ollama:qwen3:1.7b" (local, code task, moderate complexity)

        model_spec = router.route("Audit production deployment for security vulnerabilities")
        # → cloud model (critical task, needs strong reasoning)
    """

    def __init__(self, policy: RoutingPolicy | None = None, models: list[ModelOption] | None = None):
        self.policy = policy or RoutingPolicy()
        self.classifier = TaskComplexityClassifier()
        self.selector = ModelSelector(models or [], self.policy)
        self._decisions: list[RoutingDecision] = []  # append-only audit log
        self._total_cost = 0.0  # cumulative USD recorded via record_cost()

    def route(self, task: str) -> str:
        """Route *task* to the best model and return its spec string.

        Every call appends a :class:`RoutingDecision` to :attr:`decisions`
        so routing is auditable and reproducible.
        """
        complexity = self.classifier.classify(task)
        selected = self.selector.select(complexity)

        # Budget check: once the recorded spend hits the cap, force local
        # (local models are treated as zero marginal cost).
        # NOTE(review): this compares the *cumulative* total against
        # ``max_cost_per_task_usd``, whose name suggests a per-task cap —
        # confirm the intended budget scope.
        if self._total_cost >= self.policy.max_cost_per_task_usd:
            selected = self.policy.local_model
            reason = "budget_exceeded: forced local"
        else:
            reason = f"complexity={complexity.value}"

        decision = RoutingDecision(
            task_summary=task[:80],  # truncated to keep the audit log compact
            complexity=complexity,
            selected_model=selected,
            reason=reason,
        )
        self._decisions.append(decision)
        # Lazy %-args: formatting is skipped entirely when INFO is disabled.
        logger.info("Router: %s → %s (%s)", complexity.value, selected, reason)
        return selected

    def record_cost(self, cost_usd: float) -> None:
        """Record cost of a completed call for budget tracking."""
        self._total_cost += cost_usd

    @property
    def total_cost(self) -> float:
        """Cumulative recorded cost in USD since construction/reset."""
        return self._total_cost

    @property
    def decisions(self) -> list[RoutingDecision]:
        """All routing decisions made so far (the internal list itself)."""
        return self._decisions

    def reset_budget(self) -> None:
        """Zero the cumulative cost counter (decision log is kept)."""
        self._total_cost = 0.0