"""
R2-Router: LLM Router with Joint Model-Budget Optimization

Self-contained inference module. Routes queries to the optimal
(model, token_budget) pair by predicting per-query quality and cost using KNN.

Usage:
    from router import R2Router

    # Option A: Local vLLM (loads Qwen3-0.6B on first call)
    router = R2Router.from_pretrained(path)
    result = router.route_text("What is the capital of France?")

    # Option B: Remote vLLM server (no local GPU needed for embedding)
    # Start server: vllm serve Qwen/Qwen3-0.6B --runner pooling
    router = R2Router.from_pretrained(path, embed_url="http://localhost:8000")
    result = router.route_text("What is the capital of France?")

    # Option C: Pre-computed embedding
    result = router.route(embedding)  # np.ndarray (1024,)
"""

import json
import os
from typing import Dict, List, Optional, Tuple, Union

import joblib
import numpy as np
from sklearn.neighbors import KNeighborsRegressor


class R2Router:
    """
    R2-Router: Routes queries to optimal (LLM, token_budget) pair.

    Uses KNN to predict quality for each (model, budget) combination and the
    expected number of output tokens for each model, then selects the pair
    that maximizes the score (reported under the key ``"risk"``):

        risk = (1 - lambda) * quality - lambda * tokens * price / 1e6
    """

    # Embedding model used by both the local and the remote embedding paths.
    EMBED_MODEL = "Qwen/Qwen3-0.6B"
    # Output-token estimate used when a model has no token predictor.
    DEFAULT_PREDICTED_TOKENS = 50.0
    # Minimum number of labelled samples required to fit a predictor.
    MIN_TRAIN_SAMPLES = 3

    def __init__(
        self,
        quality_knns: Dict[str, Dict[str, KNeighborsRegressor]],
        token_knns: Dict[str, KNeighborsRegressor],
        model_prices: Dict[str, float],
        model_names: Dict[str, str],
        budgets: Dict[str, int],
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ):
        """
        Build a router from already-fitted predictors.

        Args:
            quality_knns: ``{model: {budget: KNN}}`` quality predictors
            token_knns: ``{model: KNN}`` output-token predictors
            model_prices: ``{model: price_per_million_output_tokens}``
            model_names: ``{short_name: full_name}``
            budgets: ``{budget_name: token_limit}``
            lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
            embed_url: vLLM server URL, e.g. "http://localhost:8000". If None,
                the embedding model is loaded locally on first use.
        """
        self.quality_knns = quality_knns
        self.token_knns = token_knns
        self.model_prices = model_prices
        self.model_names = model_names
        self.budgets = budgets
        self.lambda_val = lambda_val
        self.embed_url = embed_url
        self._embedder = None  # local vLLM engine, created lazily

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    @classmethod
    def from_pretrained(
        cls,
        path: str,
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ) -> "R2Router":
        """
        Load pre-trained KNN checkpoints.

        Checkpoints that do not exist on disk are skipped silently, so the
        resulting router may cover only a subset of (model, budget) pairs.

        Args:
            path: Local directory or HuggingFace repo ID (e.g., "JiaqiXue/r2-router")
            lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
            embed_url: vLLM server URL for embedding (e.g., "http://localhost:8000").
                       If None, loads Qwen3-0.6B locally on first route_text() call.

        Returns:
            A configured R2Router.
        """
        if not os.path.isdir(path):
            path = cls._download_from_hf(path)

        config = cls._load_config(path)
        ckpt_dir = os.path.join(path, "checkpoints")

        # NOTE(security): joblib.load unpickles arbitrary Python objects.
        # Only load checkpoints from a directory / repository you trust.
        quality_knns = {}
        token_knns = {}
        for model_name in config["models"]:
            quality_knns[model_name] = {}
            for budget_name in config["budgets"]:
                ckpt_path = os.path.join(
                    ckpt_dir, f"quality_knn_{model_name}_{budget_name}.joblib"
                )
                if os.path.exists(ckpt_path):
                    quality_knns[model_name][budget_name] = joblib.load(ckpt_path)

            tok_path = os.path.join(ckpt_dir, f"token_knn_{model_name}.joblib")
            if os.path.exists(tok_path):
                token_knns[model_name] = joblib.load(tok_path)

        model_prices, model_names = cls._model_tables(config)
        return cls(
            quality_knns=quality_knns,
            token_knns=token_knns,
            model_prices=model_prices,
            model_names=model_names,
            budgets=config["budgets"],
            lambda_val=lambda_val,
            embed_url=embed_url,
        )

    @classmethod
    def from_training_data(
        cls,
        path: str,
        k: int = 80,
        lambda_val: float = 0.999,
        embed_url: Optional[str] = None,
    ) -> "R2Router":
        """
        Train KNN from scratch using the provided training data.

        Reads ``training_data/embeddings.npy`` and ``training_data/labels.json``
        under ``path``. One quality predictor is fitted per (model, budget)
        from the ``"accuracy"`` labels; one token predictor is fitted per model
        from the ``"output_tokens"`` labels of the ``"concise"`` budget.
        Label entries that are ``None`` are ignored, and a predictor is only
        fitted when at least three labelled samples remain.

        Args:
            path: Local directory or HuggingFace repo ID
            k: Number of KNN neighbors
            lambda_val: Cost-accuracy tradeoff
            embed_url: vLLM server URL for embedding; if None, the embedding
                model is loaded locally on first route_text() call.

        Returns:
            A configured R2Router.
        """
        if not os.path.isdir(path):
            path = cls._download_from_hf(path)

        config = cls._load_config(path)
        X_train = np.load(os.path.join(path, "training_data", "embeddings.npy"))
        labels_path = os.path.join(path, "training_data", "labels.json")
        with open(labels_path, encoding="utf-8") as f:
            labels = json.load(f)

        print(f"Training router on {len(X_train)} samples (k={k})...")

        quality_knns = {}
        token_knns = {}
        n_quality = 0
        n_token = 0
        for model_name, model_labels in labels.items():
            quality_knns[model_name] = {}
            for budget_name, bdata in model_labels.items():
                knn = cls._fit_knn(X_train, bdata["accuracy"], k)
                if knn is None:
                    continue
                quality_knns[model_name][budget_name] = knn
                n_quality += 1

            # Token usage is learned from the "concise" budget only.
            concise = model_labels.get("concise")
            if concise is not None and "output_tokens" in concise:
                tknn = cls._fit_knn(X_train, concise["output_tokens"], k)
                if tknn is not None:
                    token_knns[model_name] = tknn
                    n_token += 1

        print(f"Trained {n_quality} quality predictors + {n_token} token predictors for {len(quality_knns)} models.")

        model_prices, model_names = cls._model_tables(config)
        return cls(
            quality_knns=quality_knns,
            token_knns=token_knns,
            model_prices=model_prices,
            model_names=model_names,
            budgets=config["budgets"],
            lambda_val=lambda_val,
            embed_url=embed_url,
        )

    @staticmethod
    def _load_config(path: str) -> dict:
        """Read and parse ``config.json`` from the directory ``path``."""
        with open(os.path.join(path, "config.json"), encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _model_tables(config: dict) -> Tuple[Dict[str, float], Dict[str, str]]:
        """Extract the ``(model_prices, model_names)`` tables from a config."""
        models = config["models"]
        prices = {mn: cfg["output_price_per_million"] for mn, cfg in models.items()}
        names = {mn: cfg["full_name"] for mn, cfg in models.items()}
        return prices, names

    @classmethod
    def _fit_knn(cls, X_train, values, k: int) -> Optional[KNeighborsRegressor]:
        """Fit a cosine KNN on the non-missing ``values``; None if too few."""
        # dtype=float keeps the NaN mask well defined even for all-int or
        # all-bool label lists.
        y = np.array([np.nan if v is None else v for v in values], dtype=float)
        valid = ~np.isnan(y)
        n_valid = int(valid.sum())
        if n_valid < cls.MIN_TRAIN_SAMPLES:
            return None
        knn = KNeighborsRegressor(
            n_neighbors=min(k, n_valid - 1),
            metric="cosine",
            weights="distance",
        )
        knn.fit(X_train[valid], y[valid])
        return knn

    @staticmethod
    def _download_from_hf(repo_id: str) -> str:
        """
        Download model from Hugging Face Hub.

        Args:
            repo_id: Hugging Face repository ID

        Returns:
            Whatever ``snapshot_download(repo_id)`` returns (used by the
            callers as the local directory of the snapshot).

        Raises:
            ImportError: If ``huggingface_hub`` is not installed.
        """
        # Keep the try body to the import only, so that errors raised by the
        # download itself are not misreported as a missing dependency.
        try:
            from huggingface_hub import snapshot_download
        except ImportError as err:
            raise ImportError(
                "huggingface_hub is required to download from HF. "
                "Install with: pip install huggingface_hub"
            ) from err
        return snapshot_download(repo_id)

    # ------------------------------------------------------------------
    # Embedding
    # ------------------------------------------------------------------

    def embed(self, queries: Union[str, List[str]]) -> np.ndarray:
        """
        Embed queries using Qwen3-0.6B.

        If embed_url is set, uses a remote vLLM server (OpenAI-compatible API).
        Otherwise, loads Qwen3-0.6B locally via vLLM (on first call).

        Args:
            queries: Single query string or list of queries

        Returns:
            numpy array of shape (N, 1024)
        """
        if isinstance(queries, str):
            queries = [queries]
        if self.embed_url:
            return self._embed_remote(queries)
        return self._embed_local(queries)

    def _embed_remote(self, queries: List[str]) -> np.ndarray:
        """Embed via a running vLLM server (OpenAI-compatible embeddings API)."""
        import urllib.request

        url = self.embed_url.rstrip("/") + "/v1/embeddings"
        payload = json.dumps({
            "model": self.EMBED_MODEL,
            "input": queries,
        }).encode()
        req = urllib.request.Request(
            url,
            data=payload,
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req) as resp:
            result = json.loads(resp.read())

        # Restore input order using the "index" field of each returned item.
        ordered = sorted(result["data"], key=lambda item: item["index"])
        return np.array([item["embedding"] for item in ordered])

    def _embed_local(self, queries: List[str]) -> np.ndarray:
        """Embed by loading Qwen3-0.6B locally via vLLM."""
        if self._embedder is None:
            try:
                from vllm import LLM
            except ImportError as err:
                raise ImportError(
                    "vLLM is required for local embedding. "
                    "Install with: uv pip install vllm\n"
                    "Or start a vLLM server and pass embed_url to from_pretrained()."
                ) from err
            # Created once and cached on the instance for subsequent calls.
            self._embedder = LLM(
                model=self.EMBED_MODEL,
                runner="pooling",
                trust_remote_code=True,
                dtype="half",
            )
        outputs = self._embedder.embed(queries)
        return np.array([o.outputs.embedding for o in outputs])

    # ------------------------------------------------------------------
    # Routing
    # ------------------------------------------------------------------

    def route_text(
        self,
        query: Union[str, List[str]],
        lambda_val: Optional[float] = None,
    ) -> Union[Dict, List[Dict]]:
        """
        Route text query(ies) end-to-end: embed with Qwen3-0.6B, then route.

        Args:
            query: Single query string or list of queries
            lambda_val: Override default lambda

        Returns:
            Routing decision dict (single) or list of dicts (batch)
        """
        embeddings = self.embed(query)
        if isinstance(query, str):
            return self.route(embeddings[0], lambda_val)
        return self.route_batch(embeddings, lambda_val)

    def route(
        self,
        embedding: np.ndarray,
        lambda_val: Optional[float] = None,
    ) -> Dict:
        """
        Route a query to the optimal (model, token_budget) pair.

        Args:
            embedding: Query embedding vector, shape (1024,) or (1, 1024)
            lambda_val: Override default lambda (higher = more cost-sensitive)

        Returns:
            Dict with keys: model, model_full_name, budget, token_limit,
            predicted_quality, predicted_tokens, predicted_cost, risk,
            all_options. ``all_options`` lists every candidate (without a
            nested ``all_options`` key), so the result is JSON-serializable.

        Raises:
            RuntimeError: If no quality predictor is available.
        """
        embedding = np.asarray(embedding)
        if embedding.ndim == 1:
            embedding = embedding.reshape(1, -1)

        lam = lambda_val if lambda_val is not None else self.lambda_val

        all_options = []
        for mn, budget_knns in self.quality_knns.items():
            price = self.model_prices.get(mn, 0)
            # Token usage is predicted once per model and shared by all of
            # its budgets; clamp to at least one token.
            if mn in self.token_knns:
                tok = max(1.0, float(self.token_knns[mn].predict(embedding)[0]))
            else:
                tok = self.DEFAULT_PREDICTED_TOKENS
            cost = tok * price / 1e6

            for budget_name, knn in budget_knns.items():
                q = float(knn.predict(embedding)[0])
                risk = (1 - lam) * q - lam * tok * price / 1e6
                all_options.append({
                    "model": mn,
                    "model_full_name": self.model_names.get(mn, mn),
                    "budget": budget_name,
                    "token_limit": self.budgets.get(budget_name, budget_name),
                    "predicted_quality": q,
                    "predicted_tokens": tok,
                    "predicted_cost": cost,
                    "risk": risk,
                })

        if not all_options:
            raise RuntimeError("No valid routing options")

        # Copy the winner before attaching the candidate list: mutating the
        # element in place would make the result contain itself (a reference
        # cycle that breaks json.dumps and pollutes one entry of all_options).
        best = dict(max(all_options, key=lambda opt: opt["risk"]))
        best["all_options"] = all_options
        return best

    def route_batch(
        self,
        embeddings: np.ndarray,
        lambda_val: Optional[float] = None,
    ) -> List[Dict]:
        """
        Route a batch of queries.

        Args:
            embeddings: Query embeddings, shape (N, 1024)
            lambda_val: Override default lambda

        Returns:
            List of routing decisions
        """
        return [self.route(row, lambda_val) for row in embeddings]