# Author: JiaqiXue
# Last commit: 3858b5d (verified) - "docs: use generic predictor terminology in training output"
"""
R2-Router: LLM Router with Joint Model-Budget Optimization

Self-contained inference module. Routes queries to the optimal (model, token_budget)
pair by predicting per-query quality and output-token cost with KNN regressors.

Usage:
    from router import R2Router

    # Option A: Local vLLM (loads Qwen3-0.6B on the first route_text() call)
    router = R2Router.from_pretrained(path)
    result = router.route_text("What is the capital of France?")

    # Option B: Remote vLLM server (no local GPU needed for embedding)
    # Start server: vllm serve Qwen/Qwen3-0.6B --runner pooling
    router = R2Router.from_pretrained(path, embed_url="http://localhost:8000")
    result = router.route_text("What is the capital of France?")

    # Option C: Pre-computed embedding
    result = router.route(embedding)  # np.ndarray of shape (1024,)
"""
import os
import json
import numpy as np
import joblib
from typing import Dict, List, Optional, Union
from sklearn.neighbors import KNeighborsRegressor
class R2Router:
    """
    R2-Router: Routes queries to optimal (LLM, token_budget) pair.

    For every model, one KNN regressor per budget predicts answer quality and
    one KNN regressor predicts the number of output tokens. The router then
    selects the (model, budget) pair that maximizes

        risk = (1 - lambda) * quality - lambda * tokens * price / 1e6

    where ``price`` is the model's price per million output tokens. Despite
    its name, ``risk`` is a utility score: the option with the highest value
    is chosen.
    """

    # Seconds to wait for a response from the remote embedding server
    # (``embed_url``); set to None to wait indefinitely.
    embed_timeout: Optional[float] = 120.0

    def __init__(
        self,
        quality_knns: Dict[str, Dict[str, "KNeighborsRegressor"]],
        token_knns: Dict[str, "KNeighborsRegressor"],
        model_prices: Dict[str, float],
        model_names: Dict[str, str],
        budgets: Dict[str, int],
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ):
        """
        Args:
            quality_knns: {model: {budget: regressor}} predicting answer quality.
            token_knns: {model: regressor} predicting output-token count.
            model_prices: {model: price per million output tokens}.
            model_names: {short_name: full_name}.
            budgets: {budget_name: token_limit}.
            lambda_val: Default cost-accuracy tradeoff (higher = more cost-sensitive).
            embed_url: vLLM server URL, e.g. "http://localhost:8000". When None,
                Qwen3-0.6B is loaded locally on the first embed() call.
        """
        self.quality_knns = quality_knns
        self.token_knns = token_knns
        self.model_prices = model_prices
        self.model_names = model_names
        self.budgets = budgets
        self.lambda_val = lambda_val
        self.embed_url = embed_url
        self._embedder = None  # local vLLM model, created lazily by _embed_local()

    # ------------------------------------------------------------------ loading

    @classmethod
    def from_pretrained(
        cls,
        path: str,
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ) -> "R2Router":
        """
        Load pre-trained KNN checkpoints.

        Args:
            path: Local directory or HuggingFace repo ID (e.g., "JiaqiXue/r2-router")
            lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
            embed_url: vLLM server URL for embedding (e.g., "http://localhost:8000").
                If None, loads Qwen3-0.6B locally on first route_text() call.

        Returns:
            A router holding every checkpoint found under ``<path>/checkpoints``.
            A (model, budget) pair whose checkpoint file is missing is never offered.
        """
        path = cls._resolve_path(path)
        config = cls._read_config(path)
        ckpt_dir = os.path.join(path, "checkpoints")
        quality_knns = {}
        token_knns = {}
        for model_name in config["models"]:
            quality_knns[model_name] = {}
            for budget_name in config["budgets"]:
                ckpt_path = os.path.join(ckpt_dir, f"quality_knn_{model_name}_{budget_name}.joblib")
                if os.path.exists(ckpt_path):
                    # NOTE(security): joblib.load unpickles the file; only load
                    # checkpoints from a repository you trust.
                    quality_knns[model_name][budget_name] = joblib.load(ckpt_path)
            tok_path = os.path.join(ckpt_dir, f"token_knn_{model_name}.joblib")
            if os.path.exists(tok_path):
                token_knns[model_name] = joblib.load(tok_path)
        model_prices, model_names = cls._model_metadata(config)
        return cls(
            quality_knns=quality_knns,
            token_knns=token_knns,
            model_prices=model_prices,
            model_names=model_names,
            budgets=config["budgets"],
            lambda_val=lambda_val,
            embed_url=embed_url,
        )

    @classmethod
    def from_training_data(
        cls,
        path: str,
        k: int = 80,
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ) -> "R2Router":
        """
        Train KNN from scratch using the provided training data.

        Reads ``training_data/embeddings.npy`` (one row per training query) and
        ``training_data/labels.json`` ({model: {budget: {"accuracy": [...],
        "output_tokens": [...]}}}, with None for missing labels).

        Args:
            path: Local directory or HuggingFace repo ID
            k: Number of KNN neighbors
            lambda_val: Cost-accuracy tradeoff
            embed_url: Optional vLLM server URL used by route_text()/embed().

        Raises:
            ValueError: if a label list does not have one entry per embedding.
        """
        path = cls._resolve_path(path)
        config = cls._read_config(path)
        data_dir = os.path.join(path, "training_data")
        X_train = np.load(os.path.join(data_dir, "embeddings.npy"))
        with open(os.path.join(data_dir, "labels.json"), encoding="utf-8") as f:
            labels = json.load(f)
        print(f"Training router on {len(X_train)} samples (k={k})...")
        quality_knns = {}
        token_knns = {}
        for model_name, model_labels in labels.items():
            quality_knns[model_name] = {}
            for budget_name, bdata in model_labels.items():
                knn = cls._fit_knn(X_train, bdata["accuracy"], k)
                if knn is not None:
                    quality_knns[model_name][budget_name] = knn
            # The output-length predictor is trained on the "concise" budget only.
            if "concise" in model_labels and "output_tokens" in model_labels["concise"]:
                tknn = cls._fit_knn(X_train, model_labels["concise"]["output_tokens"], k)
                if tknn is not None:
                    token_knns[model_name] = tknn
        n_quality = sum(len(per_budget) for per_budget in quality_knns.values())
        n_token = len(token_knns)
        print(f"Trained {n_quality} quality predictors + {n_token} token predictors for {len(quality_knns)} models.")
        model_prices, model_names = cls._model_metadata(config)
        return cls(
            quality_knns=quality_knns,
            token_knns=token_knns,
            model_prices=model_prices,
            model_names=model_names,
            budgets=config["budgets"],
            lambda_val=lambda_val,
            embed_url=embed_url,
        )

    @classmethod
    def _resolve_path(cls, path: str) -> str:
        """Return ``path`` if it is a local directory, else download it as an HF repo."""
        return path if os.path.isdir(path) else cls._download_from_hf(path)

    @staticmethod
    def _read_config(path: str) -> Dict:
        """Parse ``<path>/config.json``."""
        with open(os.path.join(path, "config.json"), encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _model_metadata(config: Dict):
        """Return ({model: output_price_per_million}, {model: full_name}) from a config dict."""
        models = config["models"]
        model_prices = {mn: cfg["output_price_per_million"] for mn, cfg in models.items()}
        model_names = {mn: cfg["full_name"] for mn, cfg in models.items()}
        return model_prices, model_names

    @staticmethod
    def _fit_knn(X: np.ndarray, targets: List, k: int) -> Optional["KNeighborsRegressor"]:
        """
        Fit a cosine, distance-weighted KNN regressor on rows whose target is not None/NaN.

        Returns None when fewer than 3 usable targets remain.

        Raises:
            ValueError: if ``targets`` does not have one entry per row of ``X``.
        """
        # dtype=float so boolean/int label lists are valid inputs to np.isnan.
        y = np.array([np.nan if t is None else t for t in targets], dtype=float)
        if len(y) != len(X):
            raise ValueError(
                f"Expected {len(X)} labels (one per training embedding), got {len(y)}"
            )
        valid = ~np.isnan(y)
        n_valid = int(valid.sum())
        if n_valid < 3:
            return None
        knn = KNeighborsRegressor(
            n_neighbors=min(k, n_valid - 1),
            metric="cosine",
            weights="distance",
        )
        knn.fit(X[valid], y[valid])
        return knn

    @staticmethod
    def _download_from_hf(repo_id: str) -> str:
        """Download a snapshot of ``repo_id`` from the Hugging Face Hub; return its local path."""
        try:
            from huggingface_hub import snapshot_download
        except ImportError as err:
            raise ImportError(
                "huggingface_hub is required to download from HF. "
                "Install with: pip install huggingface_hub"
            ) from err
        return snapshot_download(repo_id)

    # ---------------------------------------------------------------- embedding

    def embed(self, queries: Union[str, List[str]]) -> np.ndarray:
        """
        Embed queries using Qwen3-0.6B.

        If embed_url is set, uses a remote vLLM server (OpenAI-compatible API).
        Otherwise, loads Qwen3-0.6B locally via vLLM (on first call).

        Args:
            queries: Single query string or list of queries

        Returns:
            numpy array of shape (N, 1024)
        """
        queries = [queries] if isinstance(queries, str) else list(queries)
        if self.embed_url:
            return self._embed_remote(queries)
        return self._embed_local(queries)

    def _embed_remote(self, queries: List[str]) -> np.ndarray:
        """Embed via a running vLLM server (OpenAI-compatible ``/v1/embeddings`` API)."""
        import urllib.request

        url = self.embed_url.rstrip("/") + "/v1/embeddings"
        payload = json.dumps({
            "model": "Qwen/Qwen3-0.6B",
            "input": queries,
        }).encode()
        req = urllib.request.Request(
            url, data=payload,
            headers={"Content-Type": "application/json"},
        )
        # A bounded timeout keeps a stalled server from hanging the caller forever.
        with urllib.request.urlopen(req, timeout=self.embed_timeout) as resp:
            result = json.loads(resp.read())
        # Re-order by "index" so rows line up with the input queries.
        ordered = sorted(result["data"], key=lambda item: item["index"])
        return np.array([item["embedding"] for item in ordered])

    def _embed_local(self, queries: List[str]) -> np.ndarray:
        """Embed by loading Qwen3-0.6B locally via vLLM (model is created once and cached)."""
        if self._embedder is None:
            try:
                from vllm import LLM
            except ImportError as err:
                raise ImportError(
                    "vLLM is required for local embedding. "
                    "Install with: uv pip install vllm\n"
                    "Or start a vLLM server and pass embed_url to from_pretrained()."
                ) from err
            self._embedder = LLM(
                model="Qwen/Qwen3-0.6B",
                runner="pooling",
                trust_remote_code=True,
                dtype="half",
            )
        outputs = self._embedder.embed(queries)
        return np.array([o.outputs.embedding for o in outputs])

    # ------------------------------------------------------------------ routing

    def route_text(
        self,
        query: Union[str, List[str]],
        lambda_val: Optional[float] = None,
    ) -> Union[Dict, List[Dict]]:
        """
        Route text query(ies) end-to-end: embed with Qwen3-0.6B, then route.

        Args:
            query: Single query string or list of queries
            lambda_val: Override default lambda

        Returns:
            Routing decision dict (single) or list of dicts (batch; empty for
            an empty list, without loading or calling the embedder).
        """
        if isinstance(query, str):
            return self.route(self.embed(query)[0], lambda_val)
        queries = list(query)
        if not queries:
            return []
        return self.route_batch(self.embed(queries), lambda_val)

    def route(
        self,
        embedding: np.ndarray,
        lambda_val: Optional[float] = None,
    ) -> Dict:
        """
        Route a query to the optimal (model, token_budget) pair.

        Args:
            embedding: Query embedding vector, shape (1024,) or (1, 1024)
            lambda_val: Override default lambda (higher = more cost-sensitive)

        Returns:
            Dict with keys: model, model_full_name, budget, token_limit,
            predicted_quality, predicted_tokens, predicted_cost, risk, and
            all_options (every scored option, without nested all_options).

        Raises:
            ValueError: if more than one embedding row is given (use route_batch).
            RuntimeError: if no (model, budget) predictor is available.
        """
        X = np.asarray(embedding)
        if X.ndim == 1:
            X = X.reshape(1, -1)
        if X.ndim != 2 or X.shape[0] != 1:
            raise ValueError(
                f"route() expects one embedding of shape (D,) or (1, D), got {X.shape}; "
                "use route_batch() for several queries"
            )
        lam = self.lambda_val if lambda_val is None else lambda_val
        return self._select_best(self._score_options(X, lam)[0])

    def route_batch(
        self,
        embeddings: np.ndarray,
        lambda_val: Optional[float] = None,
    ) -> List[Dict]:
        """
        Route a batch of queries.

        Each predictor is queried once for the whole batch.

        Args:
            embeddings: Query embeddings, shape (N, 1024)
            lambda_val: Override default lambda

        Returns:
            List of routing decisions (same format as route()), one per row.
        """
        X = np.asarray(embeddings)
        if len(X) == 0:
            return []
        if X.ndim != 2:
            raise ValueError(f"route_batch() expects shape (N, D), got {X.shape}")
        lam = self.lambda_val if lambda_val is None else lambda_val
        return [self._select_best(options) for options in self._score_options(X, lam)]

    def _score_options(self, X: np.ndarray, lam: float) -> List[List[Dict]]:
        """Score every (model, budget) option for each row of the 2-D matrix X; one option list per row."""
        n_rows = X.shape[0]
        per_row: List[List[Dict]] = [[] for _ in range(n_rows)]
        for mn, budget_knns in self.quality_knns.items():
            if not budget_knns:
                continue  # no quality predictor for any budget of this model
            price = self.model_prices.get(mn, 0)
            full_name = self.model_names.get(mn, mn)
            if mn in self.token_knns:
                # Floor at 1 token so predicted cost stays positive for priced models.
                tokens = [max(1.0, float(t)) for t in self.token_knns[mn].predict(X)]
            else:
                tokens = [50.0] * n_rows  # fallback when no token predictor exists
            for budget_name, knn in budget_knns.items():
                token_limit = self.budgets.get(budget_name, budget_name)
                for options, q_raw, tok in zip(per_row, knn.predict(X), tokens):
                    q = float(q_raw)
                    options.append({
                        "model": mn,
                        "model_full_name": full_name,
                        "budget": budget_name,
                        "token_limit": token_limit,
                        "predicted_quality": q,
                        "predicted_tokens": tok,
                        "predicted_cost": tok * price / 1e6,
                        "risk": (1 - lam) * q - lam * tok * price / 1e6,
                    })
        return per_row

    @staticmethod
    def _select_best(options: List[Dict]) -> Dict:
        """Return a copy of the highest-``risk`` option with ``all_options`` attached."""
        if not options:
            raise RuntimeError("No valid routing options")
        # Copy the winner: attaching all_options to the original element would make
        # the result contain itself (breaks json.dumps and deep copies).
        best = dict(max(options, key=lambda option: option["risk"]))
        best["all_options"] = options
        return best