vLLM-based route_text(), uv install guide, complete runnable examples
Browse files
README.md
CHANGED
|
@@ -33,17 +33,22 @@ Official leaderboard results on 8,400 queries:
|
|
| 33 |
|
| 34 |
### Installation
|
| 35 |
|
|
|
|
|
|
|
| 36 |
```bash
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
```
|
| 39 |
|
| 40 |
-
### Complete Example
|
| 41 |
|
| 42 |
```python
|
| 43 |
from huggingface_hub import snapshot_download
|
| 44 |
-
import sys
|
| 45 |
-
import numpy as np
|
| 46 |
-
from transformers import AutoModel, AutoTokenizer
|
| 47 |
|
| 48 |
# 1. Download router
|
| 49 |
path = snapshot_download("JiaqiXue/r2-router")
|
|
@@ -54,48 +59,54 @@ from router import R2Router
|
|
| 54 |
# 2. Load pre-trained KNN checkpoints
|
| 55 |
router = R2Router.from_pretrained(path)
|
| 56 |
|
| 57 |
-
# 3.
|
| 58 |
-
|
| 59 |
-
model = AutoModel.from_pretrained("Qwen/Qwen3-0.6B")
|
| 60 |
-
|
| 61 |
-
query = "What is the capital of France?"
|
| 62 |
-
inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
|
| 63 |
-
with torch.no_grad():
|
| 64 |
-
output = model(**inputs)
|
| 65 |
-
embedding = output.last_hidden_state.mean(dim=1).squeeze().numpy()
|
| 66 |
-
|
| 67 |
-
# 4. Route!
|
| 68 |
-
result = router.route(embedding)
|
| 69 |
print(f"Model: {result['model_full_name']}")
|
| 70 |
print(f"Token Budget: {result['token_limit']}")
|
| 71 |
print(f"Predicted Quality: {result['predicted_quality']:.3f}")
|
| 72 |
```
|
| 73 |
|
| 74 |
-
|
| 75 |
|
| 76 |
```python
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
|
|
|
| 82 |
|
|
|
|
|
|
|
| 83 |
from router import R2Router
|
| 84 |
|
| 85 |
-
|
| 86 |
-
router = R2Router.from_training_data(path, k=80)
|
| 87 |
-
```
|
| 88 |
|
| 89 |
-
#
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
|
| 92 |
|
| 93 |
```python
|
| 94 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
embedding = outputs[0].outputs.embedding
|
| 99 |
```
|
| 100 |
|
| 101 |
## Architecture
|
|
@@ -155,7 +166,7 @@ Output: (model_name, token_budget)
|
|
| 155 |
|
| 156 |
```
|
| 157 |
config.json # Router configuration (models, budgets, prices, hyperparams)
|
| 158 |
-
router.py # Self-contained inference code
|
| 159 |
training_data/
|
| 160 |
embeddings.npy # Sub_10 training embeddings (809 x 1024)
|
| 161 |
labels.json # Per-(model, budget) accuracy & token labels
|
|
@@ -164,10 +175,13 @@ checkpoints/
|
|
| 164 |
token_knn_*.joblib # Pre-fitted KNN token predictors (6 total)
|
| 165 |
```
|
| 166 |
|
| 167 |
-
###
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
## Training Details
|
| 173 |
|
|
|
|
| 33 |
|
| 34 |
### Installation
|
| 35 |
|
| 36 |
+
We recommend using [uv](https://docs.astral.sh/uv/) for fast, reliable environment setup:
|
| 37 |
+
|
| 38 |
```bash
|
| 39 |
+
# Install uv (if not already installed)
|
| 40 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 41 |
+
|
| 42 |
+
# Create environment and install dependencies
|
| 43 |
+
uv venv .venv && source .venv/bin/activate
|
| 44 |
+
uv pip install scikit-learn numpy joblib huggingface_hub vllm
|
| 45 |
```
|
| 46 |
|
| 47 |
+
### Complete Example (GPU)
|
| 48 |
|
| 49 |
```python
|
| 50 |
from huggingface_hub import snapshot_download
|
| 51 |
+
import sys
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# 1. Download router
|
| 54 |
path = snapshot_download("JiaqiXue/r2-router")
|
|
|
|
| 59 |
# 2. Load pre-trained KNN checkpoints
|
| 60 |
router = R2Router.from_pretrained(path)
|
| 61 |
|
| 62 |
+
# 3. Route a query (auto-embeds with Qwen3-0.6B via vLLM)
|
| 63 |
+
result = router.route_text("What is the capital of France?")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
print(f"Model: {result['model_full_name']}")
|
| 65 |
print(f"Token Budget: {result['token_limit']}")
|
| 66 |
print(f"Predicted Quality: {result['predicted_quality']:.3f}")
|
| 67 |
```
|
| 68 |
|
| 69 |
+
`route_text()` loads Qwen3-0.6B via vLLM on the first call and reuses the cached model on subsequent calls. Batch routing is also supported:
|
| 70 |
|
| 71 |
```python
|
| 72 |
+
queries = [
|
| 73 |
+
"What is the capital of France?",
|
| 74 |
+
"Write a Python function to sort a list",
|
| 75 |
+
"Translate 'hello' to Japanese",
|
| 76 |
+
]
|
| 77 |
+
results = router.route_text(queries)
|
| 78 |
+
for q, r in zip(queries, results):
|
| 79 |
+
print(f"{q[:40]:40s} -> {r['model']} (budget={r['token_limit']})")
|
| 80 |
+
```
|
| 81 |
|
| 82 |
+
### CPU-Only (No GPU)
|
| 83 |
+
|
| 84 |
+
If you don't have a GPU, provide pre-computed embeddings directly (here `path` is the directory returned by `snapshot_download` in the example above):
|
| 85 |
|
| 86 |
+
```python
|
| 87 |
+
import numpy as np
|
| 88 |
from router import R2Router
|
| 89 |
|
| 90 |
+
router = R2Router.from_pretrained(path)
|
|
|
|
|
|
|
| 91 |
|
| 92 |
+
# Your own 1024-dim embedding (e.g., from an API or pre-computed)
|
| 93 |
+
embedding = np.random.randn(1024) # replace with real embedding
|
| 94 |
+
result = router.route(embedding)
|
| 95 |
+
```
|
| 96 |
|
| 97 |
+
### Train from Scratch
|
| 98 |
|
| 99 |
```python
|
| 100 |
+
from huggingface_hub import snapshot_download
|
| 101 |
+
import sys
|
| 102 |
+
|
| 103 |
+
path = snapshot_download("JiaqiXue/r2-router")
|
| 104 |
+
sys.path.insert(0, path)
|
| 105 |
+
|
| 106 |
+
from router import R2Router
|
| 107 |
|
| 108 |
+
# Train KNN with custom hyperparameters
|
| 109 |
+
router = R2Router.from_training_data(path, k=80, lambda_val=0.999)
|
|
|
|
| 110 |
```
|
| 111 |
|
| 112 |
## Architecture
|
|
|
|
| 166 |
|
| 167 |
```
|
| 168 |
config.json # Router configuration (models, budgets, prices, hyperparams)
|
| 169 |
+
router.py # Self-contained inference code (embed + route)
|
| 170 |
training_data/
|
| 171 |
embeddings.npy # Sub_10 training embeddings (809 x 1024)
|
| 172 |
labels.json # Per-(model, budget) accuracy & token labels
|
|
|
|
| 175 |
token_knn_*.joblib # Pre-fitted KNN token predictors (6 total)
|
| 176 |
```
|
| 177 |
|
| 178 |
+
### Three Ways to Use
|
| 179 |
|
| 180 |
+
| Method | GPU? | Description |
|
| 181 |
+
|--------|------|-------------|
|
| 182 |
+
| `router.route_text(query)` | Yes | End-to-end: auto-embeds with vLLM, then routes |
|
| 183 |
+
| `router.route(embedding)` | No | Route from pre-computed 1024-dim embedding |
|
| 184 |
+
| `R2Router.from_training_data(path)` | No | Train your own KNN with custom hyperparameters |
|
| 185 |
|
| 186 |
## Training Details
|
| 187 |
|
router.py
CHANGED
|
@@ -6,11 +6,17 @@ pair by predicting per-query quality and cost using KNN.
|
|
| 6 |
|
| 7 |
Usage:
|
| 8 |
from router import R2Router
|
| 9 |
-
router = R2Router.from_pretrained("jqxue1999/r2-router")
|
| 10 |
-
result = router.route(embedding) # embedding: np.ndarray (1024,)
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
import os
|
|
@@ -45,6 +51,7 @@ class R2Router:
|
|
| 45 |
self.model_names = model_names # {short_name: full_name}
|
| 46 |
self.budgets = budgets # {budget_name: token_limit}
|
| 47 |
self.lambda_val = lambda_val
|
|
|
|
| 48 |
|
| 49 |
@classmethod
|
| 50 |
def from_pretrained(cls, path: str, lambda_val: float = 0.999) -> "R2Router":
|
|
@@ -52,10 +59,9 @@ class R2Router:
|
|
| 52 |
Load pre-trained KNN checkpoints.
|
| 53 |
|
| 54 |
Args:
|
| 55 |
-
path: Local directory or HuggingFace repo ID (e.g., "
|
| 56 |
lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
|
| 57 |
"""
|
| 58 |
-
# If HF repo ID, download first
|
| 59 |
if not os.path.isdir(path):
|
| 60 |
path = cls._download_from_hf(path)
|
| 61 |
|
|
@@ -138,7 +144,6 @@ class R2Router:
|
|
| 138 |
knn.fit(X_train[valid], acc[valid])
|
| 139 |
quality_knns[model_name][budget_name] = knn
|
| 140 |
|
| 141 |
-
# Token predictor (use concise budget's output_tokens)
|
| 142 |
if "concise" in model_labels and "output_tokens" in model_labels["concise"]:
|
| 143 |
tok = np.array([x if x is not None else np.nan for x in model_labels["concise"]["output_tokens"]])
|
| 144 |
valid = ~np.isnan(tok)
|
|
@@ -181,6 +186,57 @@ class R2Router:
|
|
| 181 |
"Install with: pip install huggingface_hub"
|
| 182 |
)
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
def route(
|
| 185 |
self,
|
| 186 |
embedding: np.ndarray,
|
|
@@ -206,7 +262,6 @@ class R2Router:
|
|
| 206 |
for mn in self.quality_knns:
|
| 207 |
price = self.model_prices.get(mn, 0)
|
| 208 |
|
| 209 |
-
# Predict output tokens
|
| 210 |
if mn in self.token_knns:
|
| 211 |
tok = max(1.0, float(self.token_knns[mn].predict(embedding)[0]))
|
| 212 |
else:
|
|
|
|
| 6 |
|
| 7 |
Usage:
|
| 8 |
from router import R2Router
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
router = R2Router.from_pretrained(path)
|
| 11 |
+
|
| 12 |
+
# Option 1: Route from text (auto-embeds with vLLM)
|
| 13 |
+
result = router.route_text("What is the capital of France?")
|
| 14 |
+
|
| 15 |
+
# Option 2: Route from pre-computed embedding
|
| 16 |
+
result = router.route(embedding) # np.ndarray (1024,)
|
| 17 |
+
|
| 18 |
+
# Option 3: Train from scratch
|
| 19 |
+
router = R2Router.from_training_data(path, k=80)
|
| 20 |
"""
|
| 21 |
|
| 22 |
import os
|
|
|
|
| 51 |
self.model_names = model_names # {short_name: full_name}
|
| 52 |
self.budgets = budgets # {budget_name: token_limit}
|
| 53 |
self.lambda_val = lambda_val
|
| 54 |
+
self._embedder = None
|
| 55 |
|
| 56 |
@classmethod
|
| 57 |
def from_pretrained(cls, path: str, lambda_val: float = 0.999) -> "R2Router":
|
|
|
|
| 59 |
Load pre-trained KNN checkpoints.
|
| 60 |
|
| 61 |
Args:
|
| 62 |
+
path: Local directory or HuggingFace repo ID (e.g., "JiaqiXue/r2-router")
|
| 63 |
lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
|
| 64 |
"""
|
|
|
|
| 65 |
if not os.path.isdir(path):
|
| 66 |
path = cls._download_from_hf(path)
|
| 67 |
|
|
|
|
| 144 |
knn.fit(X_train[valid], acc[valid])
|
| 145 |
quality_knns[model_name][budget_name] = knn
|
| 146 |
|
|
|
|
| 147 |
if "concise" in model_labels and "output_tokens" in model_labels["concise"]:
|
| 148 |
tok = np.array([x if x is not None else np.nan for x in model_labels["concise"]["output_tokens"]])
|
| 149 |
valid = ~np.isnan(tok)
|
|
|
|
| 186 |
"Install with: pip install huggingface_hub"
|
| 187 |
)
|
| 188 |
|
| 189 |
+
def embed(self, queries: Union[str, List[str]]) -> np.ndarray:
    """
    Embed queries using Qwen3-0.6B via vLLM (loaded on first call).

    The vLLM engine is created lazily and stored on ``self._embedder`` so
    that later calls reuse it instead of reloading the model.

    Args:
        queries: Single query string or list of queries

    Returns:
        numpy array of shape (N, 1024); an empty list yields shape (0, 1024)

    Raises:
        ImportError: If vLLM is not installed and the embedder has to be created
    """
    # Normalise to a list first so the empty-input check below covers both forms.
    if isinstance(queries, str):
        queries = [queries]

    # Nothing to embed: skip the (expensive) engine start-up and return an
    # empty matrix with the documented embedding width.
    if len(queries) == 0:
        return np.empty((0, 1024))

    if self._embedder is None:
        try:
            from vllm import LLM
        except ImportError as err:
            # Chain explicitly so the original import failure stays visible.
            raise ImportError(
                "vLLM is required for text embedding. "
                "Install with: uv pip install vllm"
            ) from err
        # NOTE(review): runner="pooling" is assumed to select vLLM's
        # embedding/pooling mode for this model -- confirm against the
        # installed vLLM version.
        self._embedder = LLM(
            model="Qwen/Qwen3-0.6B",
            runner="pooling",
            trust_remote_code=True,
            dtype="half",
        )

    outputs = self._embedder.embed(queries)
    return np.array([o.outputs.embedding for o in outputs])
|
| 219 |
+
|
| 220 |
+
def route_text(
    self,
    query: Union[str, List[str]],
    lambda_val: Optional[float] = None,
) -> Union[Dict, List[Dict]]:
    """
    Embed the given text with ``self.embed`` and route the result.

    Args:
        query: Single query string or list of queries
        lambda_val: Override default lambda (passed through unchanged)

    Returns:
        The result of ``self.route`` for a single string, or the result of
        ``self.route_batch`` for a list of queries
    """
    vectors = self.embed(query)

    # A list goes to the batch router with every embedding row.
    if not isinstance(query, str):
        return self.route_batch(vectors, lambda_val)

    # A lone string was embedded as a one-row batch; route that single row.
    return self.route(vectors[0], lambda_val)
|
| 239 |
+
|
| 240 |
def route(
|
| 241 |
self,
|
| 242 |
embedding: np.ndarray,
|
|
|
|
| 262 |
for mn in self.quality_knns:
|
| 263 |
price = self.model_prices.get(mn, 0)
|
| 264 |
|
|
|
|
| 265 |
if mn in self.token_knns:
|
| 266 |
tok = max(1.0, float(self.token_knns[mn].predict(embedding)[0]))
|
| 267 |
else:
|