File size: 12,468 Bytes
bc1c255
 
 
 
 
 
 
 
 
52e5f26
5309389
52e5f26
5309389
52e5f26
603d970
52e5f26
5309389
 
52e5f26
5309389
bc1c255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52e5f26
bc1c255
 
 
 
 
 
 
52e5f26
5309389
bc1c255
 
52e5f26
 
 
 
 
 
bc1c255
 
 
 
5309389
bc1c255
52e5f26
 
bc1c255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52e5f26
bc1c255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3858b5d
9e08fe1
bc1c255
 
9e08fe1
 
bc1c255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e08fe1
bc1c255
 
 
 
 
 
 
 
 
 
 
 
9e08fe1
 
3858b5d
bc1c255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5309389
 
52e5f26
 
 
 
5309389
 
 
 
 
 
 
52e5f26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5309389
 
 
 
 
52e5f26
 
 
5309389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc1c255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
"""
R2-Router: LLM Router with Joint Model-Budget Optimization

Self-contained inference module. Routes queries to the optimal (model, token_budget)
pair by predicting per-query quality and cost using KNN.

Usage:
    from router import R2Router

    # Option A: Local vLLM (loads Qwen3-0.6B on first call)
    router = R2Router.from_pretrained(path)
    result = router.route_text("What is the capital of France?")

    # Option B: Remote vLLM server (no local GPU needed for embedding)
    #   Start server: vllm serve Qwen/Qwen3-0.6B --runner pooling
    router = R2Router.from_pretrained(path, embed_url="http://localhost:8000")
    result = router.route_text("What is the capital of France?")

    # Option C: Pre-computed embedding
    result = router.route(embedding)  # np.ndarray (1024,)
"""

import os
import json
import numpy as np
import joblib
from typing import Dict, List, Optional, Union
from sklearn.neighbors import KNeighborsRegressor


class R2Router:
    """
    R2-Router: Routes queries to optimal (LLM, token_budget) pair.

    Uses KNN to predict quality for each (model, budget) combination and the
    number of output tokens for each model, then selects the pair that maximizes:
        risk = (1 - lambda) * quality - lambda * tokens * price / 1e6

    Despite its name, ``risk`` is the score being maximized: larger is better.
    """

    # Embedding model shared by the remote and the local embedding paths.
    EMBED_MODEL = "Qwen/Qwen3-0.6B"
    # Token estimate used for a model that has no token predictor.
    DEFAULT_PREDICTED_TOKENS = 50.0

    def __init__(
        self,
        quality_knns: "Dict[str, Dict[str, KNeighborsRegressor]]",
        token_knns: "Dict[str, KNeighborsRegressor]",
        model_prices: Dict[str, float],
        model_names: Dict[str, str],
        budgets: Dict[str, int],
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ):
        """
        Build a router from already-fitted predictors.

        Args:
            quality_knns: {model: {budget_name: regressor}} quality predictors
            token_knns: {model: regressor} output-token predictors
            model_prices: {model: price per million output tokens}
            model_names: {short_name: full_name}
            budgets: {budget_name: token_limit}
            lambda_val: Default cost-accuracy tradeoff used by route()
            embed_url: vLLM server URL, e.g. "http://localhost:8000"; when
                       None, embed() loads the model locally on first use
        """
        self.quality_knns = quality_knns
        self.token_knns = token_knns
        self.model_prices = model_prices
        self.model_names = model_names
        self.budgets = budgets
        self.lambda_val = lambda_val
        self.embed_url = embed_url
        self._embedder = None  # lazily created by _embed_local()

    @classmethod
    def from_pretrained(
        cls,
        path: str,
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ) -> "R2Router":
        """
        Load pre-trained KNN checkpoints.

        Checkpoint files that do not exist are skipped silently, so the
        returned router may cover only a subset of the configured
        (model, budget) pairs.

        Args:
            path: Local directory or HuggingFace repo ID (e.g., "JiaqiXue/r2-router")
            lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
            embed_url: vLLM server URL for embedding (e.g., "http://localhost:8000").
                       If None, loads Qwen3-0.6B locally on first route_text() call.

        Returns:
            A router backed by the checkpoints found under ``<path>/checkpoints``.
        """
        if not os.path.isdir(path):
            path = cls._download_from_hf(path)

        config = cls._load_config(path)
        ckpt_dir = os.path.join(path, "checkpoints")
        quality_knns = {}
        token_knns = {}

        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a source you trust.
        for model_name in config["models"]:
            per_budget = {}
            for budget_name in config["budgets"]:
                ckpt_path = os.path.join(ckpt_dir, f"quality_knn_{model_name}_{budget_name}.joblib")
                if os.path.exists(ckpt_path):
                    per_budget[budget_name] = joblib.load(ckpt_path)
            quality_knns[model_name] = per_budget

            tok_path = os.path.join(ckpt_dir, f"token_knn_{model_name}.joblib")
            if os.path.exists(tok_path):
                token_knns[model_name] = joblib.load(tok_path)

        model_prices, model_names = cls._model_tables(config)
        return cls(
            quality_knns=quality_knns,
            token_knns=token_knns,
            model_prices=model_prices,
            model_names=model_names,
            budgets=config["budgets"],
            lambda_val=lambda_val,
            embed_url=embed_url,
        )

    @classmethod
    def from_training_data(
        cls,
        path: str,
        k: int = 80,
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ) -> "R2Router":
        """
        Train KNN from scratch using the provided training data.

        One quality predictor is fitted per (model, budget) pair and one token
        predictor per model (from the "concise" budget's output token counts).
        A predictor is skipped when fewer than 3 non-null labels are available.

        Args:
            path: Local directory or HuggingFace repo ID
            k: Number of KNN neighbors (capped at the number of valid samples - 1)
            lambda_val: Cost-accuracy tradeoff
            embed_url: vLLM server URL for embedding; same meaning as in
                       from_pretrained()

        Returns:
            A router backed by the freshly fitted predictors.
        """
        if not os.path.isdir(path):
            path = cls._download_from_hf(path)

        config = cls._load_config(path)
        data_dir = os.path.join(path, "training_data")
        X_train = np.load(os.path.join(data_dir, "embeddings.npy"))
        with open(os.path.join(data_dir, "labels.json"), encoding="utf-8") as f:
            labels = json.load(f)

        print(f"Training router on {len(X_train)} samples (k={k})...")

        quality_knns = {}
        token_knns = {}

        for model_name, model_labels in labels.items():
            per_budget = {}
            for budget_name, bdata in model_labels.items():
                knn = cls._fit_knn(X_train, bdata["accuracy"], k)
                if knn is not None:
                    per_budget[budget_name] = knn
            quality_knns[model_name] = per_budget

            # Token usage is only learned from the "concise" budget.
            token_targets = model_labels.get("concise", {}).get("output_tokens")
            if token_targets is not None:
                tknn = cls._fit_knn(X_train, token_targets, k)
                if tknn is not None:
                    token_knns[model_name] = tknn

        n_quality = sum(len(per_budget) for per_budget in quality_knns.values())
        n_token = len(token_knns)
        print(f"Trained {n_quality} quality predictors + {n_token} token predictors for {len(quality_knns)} models.")

        model_prices, model_names = cls._model_tables(config)
        return cls(
            quality_knns=quality_knns,
            token_knns=token_knns,
            model_prices=model_prices,
            model_names=model_names,
            budgets=config["budgets"],
            lambda_val=lambda_val,
            embed_url=embed_url,
        )

    @staticmethod
    def _load_config(path: str) -> Dict:
        """Read ``config.json`` from the router directory."""
        with open(os.path.join(path, "config.json"), encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _model_tables(config: Dict) -> tuple:
        """Return ``(model_prices, model_names)`` extracted from the config."""
        models = config["models"]
        model_prices = {mn: cfg["output_price_per_million"] for mn, cfg in models.items()}
        model_names = {mn: cfg["full_name"] for mn, cfg in models.items()}
        return model_prices, model_names

    @staticmethod
    def _fit_knn(X_train: np.ndarray, targets: List, k: int) -> "Optional[KNeighborsRegressor]":
        """Fit a cosine, distance-weighted KNN on the non-null targets; None if fewer than 3."""
        # dtype=float keeps np.isnan well-defined whatever the JSON value types were.
        values = np.array([np.nan if v is None else v for v in targets], dtype=float)
        valid = ~np.isnan(values)
        n_valid = int(valid.sum())
        if n_valid < 3:
            return None
        knn = KNeighborsRegressor(
            n_neighbors=min(k, n_valid - 1),
            metric="cosine",
            weights="distance",
        )
        knn.fit(X_train[valid], values[valid])
        return knn

    @staticmethod
    def _download_from_hf(repo_id: str) -> str:
        """
        Download model from Hugging Face Hub.

        Returns:
            Whatever ``snapshot_download(repo_id)`` returns (used by callers
            as the local directory path).

        Raises:
            ImportError: if huggingface_hub is not installed.
        """
        # Only the import is guarded, so an ImportError raised while
        # downloading is not misreported as a missing dependency.
        try:
            from huggingface_hub import snapshot_download
        except ImportError as err:
            raise ImportError(
                "huggingface_hub is required to download from HF. "
                "Install with: pip install huggingface_hub"
            ) from err
        return snapshot_download(repo_id)

    def embed(self, queries: Union[str, List[str]]) -> np.ndarray:
        """
        Embed queries using Qwen3-0.6B.

        If embed_url is set, uses a remote vLLM server (OpenAI-compatible API).
        Otherwise, loads Qwen3-0.6B locally via vLLM (on first call).

        Args:
            queries: Single query string or list of queries

        Returns:
            numpy array with one row per query (a single string yields one row)
        """
        if isinstance(queries, str):
            queries = [queries]

        if self.embed_url:
            return self._embed_remote(queries)
        return self._embed_local(queries)

    def _embed_remote(self, queries: List[str]) -> np.ndarray:
        """Embed via a running vLLM server (OpenAI-compatible embeddings API)."""
        import urllib.request

        url = self.embed_url.rstrip("/") + "/v1/embeddings"
        payload = json.dumps({
            "model": self.EMBED_MODEL,
            "input": queries,
        }).encode()

        req = urllib.request.Request(
            url, data=payload,
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req) as resp:
            result = json.loads(resp.read())

        # Order by the reported index so rows line up with the input queries.
        ordered = sorted(result["data"], key=lambda item: item["index"])
        return np.array([item["embedding"] for item in ordered])

    def _embed_local(self, queries: List[str]) -> np.ndarray:
        """
        Embed by loading Qwen3-0.6B locally via vLLM.

        The vLLM engine is created on the first call and reused afterwards.

        Raises:
            ImportError: if vLLM is not installed.
        """
        if self._embedder is None:
            try:
                from vllm import LLM
            except ImportError as err:
                raise ImportError(
                    "vLLM is required for local embedding. "
                    "Install with: uv pip install vllm\n"
                    "Or start a vLLM server and pass embed_url to from_pretrained()."
                ) from err
            self._embedder = LLM(
                model=self.EMBED_MODEL,
                runner="pooling",
                trust_remote_code=True,
                dtype="half",
            )

        outputs = self._embedder.embed(queries)
        return np.array([o.outputs.embedding for o in outputs])

    def route_text(
        self,
        query: Union[str, List[str]],
        lambda_val: Optional[float] = None,
    ) -> Union[Dict, List[Dict]]:
        """
        Route text query(ies) end-to-end: embed with Qwen3-0.6B, then route.

        Args:
            query: Single query string or list of queries
            lambda_val: Override default lambda

        Returns:
            Routing decision dict (single) or list of dicts (batch)
        """
        embeddings = self.embed(query)
        if isinstance(query, str):
            return self.route(embeddings[0], lambda_val)
        return self.route_batch(embeddings, lambda_val)

    def route(
        self,
        embedding: np.ndarray,
        lambda_val: Optional[float] = None,
    ) -> Dict:
        """
        Route a query to the optimal (model, token_budget) pair.

        Args:
            embedding: Query embedding vector, shape (D,) or (1, D); any
                       array-like is accepted. If several rows are passed,
                       only the first row's predictions are used.
            lambda_val: Override default lambda (higher = more cost-sensitive)

        Returns:
            Dict with keys: model, model_full_name, budget, token_limit,
                           predicted_quality, predicted_tokens, predicted_cost,
                           risk, all_options.
            ``all_options`` lists every candidate (without a nested
            ``all_options`` key), so the result is JSON-serializable.

        Raises:
            RuntimeError: if no (model, budget) predictor is available.
        """
        embedding = np.asarray(embedding)
        if embedding.ndim == 1:
            embedding = embedding.reshape(1, -1)

        lam = self.lambda_val if lambda_val is None else lambda_val
        all_options = []

        for mn, budget_knns in self.quality_knns.items():
            if not budget_knns:
                continue  # nothing to score; skip the token prediction too

            price = self.model_prices.get(mn, 0)
            token_knn = self.token_knns.get(mn)
            if token_knn is not None:
                # Clamp to at least one token so the cost term never goes <= 0
                # because of a degenerate prediction.
                tok = max(1.0, float(token_knn.predict(embedding)[0]))
            else:
                tok = self.DEFAULT_PREDICTED_TOKENS

            for budget_name, knn in budget_knns.items():
                q = float(knn.predict(embedding)[0])
                all_options.append({
                    "model": mn,
                    "model_full_name": self.model_names.get(mn, mn),
                    "budget": budget_name,
                    "token_limit": self.budgets.get(budget_name, budget_name),
                    "predicted_quality": q,
                    "predicted_tokens": tok,
                    "predicted_cost": tok * price / 1e6,
                    "risk": (1 - lam) * q - lam * tok * price / 1e6,
                })

        if not all_options:
            raise RuntimeError("No valid routing options")

        best = max(all_options, key=lambda opt: opt["risk"])
        # Return a copy: attaching all_options to `best` itself would make the
        # result contain itself (best is an element of all_options).
        return {**best, "all_options": all_options}

    def route_batch(
        self,
        embeddings: np.ndarray,
        lambda_val: Optional[float] = None,
    ) -> List[Dict]:
        """
        Route a batch of queries.

        Args:
            embeddings: Query embeddings, shape (N, D). A single 1-D embedding
                        is treated as a batch of one.
            lambda_val: Override default lambda

        Returns:
            List of routing decisions, one per row (empty for an empty batch)
        """
        batch = np.asarray(embeddings)
        if batch.size == 0:
            return []
        if batch.ndim == 1:
            batch = batch.reshape(1, -1)
        return [self.route(row, lambda_val) for row in batch]