JiaqiXue commited on
Commit
52e5f26
·
verified ·
1 Parent(s): 5309389

Add vLLM server mode: start once, route from anywhere

Browse files
Files changed (2) hide show
  1. README.md +30 -7
  2. router.py +53 -12
README.md CHANGED
@@ -79,14 +79,36 @@ for q, r in zip(queries, results):
79
  print(f"{q[:40]:40s} -> {r['model']} (budget={r['token_limit']})")
80
  ```
81
 
82
- ### CPU-Only (No GPU)
83
 
84
- If you don't have a GPU, provide pre-computed embeddings directly:
 
 
 
 
 
 
85
 
86
  ```python
87
- import numpy as np
 
 
 
 
 
 
88
  from router import R2Router
89
 
 
 
 
 
 
 
 
 
 
 
90
  router = R2Router.from_pretrained(path)
91
 
92
  # Your own 1024-dim embedding (e.g., from an API or pre-computed)
@@ -175,13 +197,14 @@ checkpoints/
175
  token_knn_*.joblib # Pre-fitted KNN token predictors (6 total)
176
  ```
177
 
178
- ### Three Ways to Use
179
 
180
  | Method | GPU? | Description |
181
  |--------|------|-------------|
182
- | `router.route_text(query)` | Yes | End-to-end: auto-embeds with vLLM, then routes |
183
- | `router.route(embedding)` | No | Route from pre-computed 1024-dim embedding |
184
- | `R2Router.from_training_data(path)` | No | Train your own KNN with custom hyperparameters |
 
185
 
186
  ## Training Details
187
 
 
79
  print(f"{q[:40]:40s} -> {r['model']} (budget={r['token_limit']})")
80
  ```
81
 
82
+ ### With vLLM Server (Recommended for Production)
83
 
84
+ Start the embedding server once, then route from any process without reloading the model:
85
+
86
+ ```bash
87
+ # Terminal 1: Start vLLM embedding server (runs once, stays alive)
88
+ uv pip install vllm
89
+ vllm serve Qwen/Qwen3-0.6B --task embed --port 8000
90
+ ```
91
 
92
  ```python
93
+ # Terminal 2: Route queries (connects to the running server)
94
+ from huggingface_hub import snapshot_download
95
+ import sys
96
+
97
+ path = snapshot_download("JiaqiXue/r2-router")
98
+ sys.path.insert(0, path)
99
+
100
  from router import R2Router
101
 
102
+ router = R2Router.from_pretrained(path, embed_url="http://localhost:8000")
103
+ result = router.route_text("What is the capital of France?")
104
+ print(f"Model: {result['model_full_name']}, Budget: {result['token_limit']}")
105
+ ```
106
+
107
+ ### CPU-Only (No GPU)
108
+
109
+ If you don't have a GPU, provide pre-computed embeddings directly:
110
+
111
+ ```python
112
  router = R2Router.from_pretrained(path)
113
 
114
  # Your own 1024-dim embedding (e.g., from an API or pre-computed)
 
197
  token_knn_*.joblib # Pre-fitted KNN token predictors (6 total)
198
  ```
199
 
200
+ ### Ways to Use
201
 
202
  | Method | GPU? | Description |
203
  |--------|------|-------------|
204
+ | `route_text()` + vLLM server | Yes (server) | Start `vllm serve` once, route from anywhere via HTTP |
205
+ | `route_text()` + local vLLM | Yes (local) | Auto-loads Qwen3-0.6B on first call, caches it |
206
+ | `route(embedding)` | No | Route from pre-computed 1024-dim embedding |
207
+ | `from_training_data(path)` | No | Train your own KNN with custom hyperparameters |
208
 
209
  ## Training Details
210
 
router.py CHANGED
@@ -7,16 +7,17 @@ pair by predicting per-query quality and cost using KNN.
7
  Usage:
8
  from router import R2Router
9
 
 
10
  router = R2Router.from_pretrained(path)
 
11
 
12
- # Option 1: Route from text (auto-embeds with vLLM)
 
 
13
  result = router.route_text("What is the capital of France?")
14
 
15
- # Option 2: Route from pre-computed embedding
16
  result = router.route(embedding) # np.ndarray (1024,)
17
-
18
- # Option 3: Train from scratch
19
- router = R2Router.from_training_data(path, k=80)
20
  """
21
 
22
  import os
@@ -44,6 +45,7 @@ class R2Router:
44
  model_names: Dict[str, str],
45
  budgets: Dict[str, int],
46
  lambda_val: float = 0.999,
 
47
  ):
48
  self.quality_knns = quality_knns # {model: {budget: KNN}}
49
  self.token_knns = token_knns # {model: KNN}
@@ -51,16 +53,24 @@ class R2Router:
51
  self.model_names = model_names # {short_name: full_name}
52
  self.budgets = budgets # {budget_name: token_limit}
53
  self.lambda_val = lambda_val
 
54
  self._embedder = None
55
 
56
  @classmethod
57
- def from_pretrained(cls, path: str, lambda_val: float = 0.999) -> "R2Router":
 
 
 
 
 
58
  """
59
  Load pre-trained KNN checkpoints.
60
 
61
  Args:
62
  path: Local directory or HuggingFace repo ID (e.g., "JiaqiXue/r2-router")
63
  lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
 
 
64
  """
65
  if not os.path.isdir(path):
66
  path = cls._download_from_hf(path)
@@ -99,6 +109,7 @@ class R2Router:
99
  model_names=model_names,
100
  budgets=config["budgets"],
101
  lambda_val=lambda_val,
 
102
  )
103
 
104
  @classmethod
@@ -188,7 +199,10 @@ class R2Router:
188
 
189
  def embed(self, queries: Union[str, List[str]]) -> np.ndarray:
190
  """
191
- Embed queries using Qwen3-0.6B via vLLM (loaded on first call).
 
 
 
192
 
193
  Args:
194
  queries: Single query string or list of queries
@@ -196,13 +210,43 @@ class R2Router:
196
  Returns:
197
  numpy array of shape (N, 1024)
198
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  if self._embedder is None:
200
  try:
201
  from vllm import LLM
202
  except ImportError:
203
  raise ImportError(
204
- "vLLM is required for text embedding. "
205
- "Install with: uv pip install vllm"
 
206
  )
207
  self._embedder = LLM(
208
  model="Qwen/Qwen3-0.6B",
@@ -211,9 +255,6 @@ class R2Router:
211
  dtype="half",
212
  )
213
 
214
- if isinstance(queries, str):
215
- queries = [queries]
216
-
217
  outputs = self._embedder.embed(queries)
218
  return np.array([o.outputs.embedding for o in outputs])
219
 
 
7
  Usage:
8
  from router import R2Router
9
 
10
+ # Option A: Local vLLM (loads Qwen3-0.6B on first call)
11
  router = R2Router.from_pretrained(path)
12
+ result = router.route_text("What is the capital of France?")
13
 
14
+ # Option B: Remote vLLM server (no local GPU needed for embedding)
15
+ # Start server: vllm serve Qwen/Qwen3-0.6B --task embed
16
+ router = R2Router.from_pretrained(path, embed_url="http://localhost:8000")
17
  result = router.route_text("What is the capital of France?")
18
 
19
+ # Option C: Pre-computed embedding
20
  result = router.route(embedding) # np.ndarray (1024,)
 
 
 
21
  """
22
 
23
  import os
 
45
  model_names: Dict[str, str],
46
  budgets: Dict[str, int],
47
  lambda_val: float = 0.999,
48
+ embed_url: Optional[str] = None,
49
  ):
50
  self.quality_knns = quality_knns # {model: {budget: KNN}}
51
  self.token_knns = token_knns # {model: KNN}
 
53
  self.model_names = model_names # {short_name: full_name}
54
  self.budgets = budgets # {budget_name: token_limit}
55
  self.lambda_val = lambda_val
56
+ self.embed_url = embed_url # vLLM server URL, e.g. "http://localhost:8000"
57
  self._embedder = None
58
 
59
  @classmethod
60
+ def from_pretrained(
61
+ cls,
62
+ path: str,
63
+ lambda_val: float = 0.999,
64
+ embed_url: Optional[str] = None,
65
+ ) -> "R2Router":
66
  """
67
  Load pre-trained KNN checkpoints.
68
 
69
  Args:
70
  path: Local directory or HuggingFace repo ID (e.g., "JiaqiXue/r2-router")
71
  lambda_val: Cost-accuracy tradeoff (higher = more cost-sensitive)
72
+ embed_url: vLLM server URL for embedding (e.g., "http://localhost:8000").
73
+ If None, loads Qwen3-0.6B locally on first route_text() call.
74
  """
75
  if not os.path.isdir(path):
76
  path = cls._download_from_hf(path)
 
109
  model_names=model_names,
110
  budgets=config["budgets"],
111
  lambda_val=lambda_val,
112
+ embed_url=embed_url,
113
  )
114
 
115
  @classmethod
 
199
 
200
  def embed(self, queries: Union[str, List[str]]) -> np.ndarray:
201
  """
202
+ Embed queries using Qwen3-0.6B.
203
+
204
+ If embed_url is set, uses a remote vLLM server (OpenAI-compatible API).
205
+ Otherwise, loads Qwen3-0.6B locally via vLLM (on first call).
206
 
207
  Args:
208
  queries: Single query string or list of queries
 
210
  Returns:
211
  numpy array of shape (N, 1024)
212
  """
213
+ if isinstance(queries, str):
214
+ queries = [queries]
215
+
216
+ if self.embed_url:
217
+ return self._embed_remote(queries)
218
+ return self._embed_local(queries)
219
+
220
+ def _embed_remote(self, queries: List[str]) -> np.ndarray:
221
+ """Embed via a running vLLM server (OpenAI-compatible embeddings API)."""
222
+ import urllib.request
223
+
224
+ url = self.embed_url.rstrip("/") + "/v1/embeddings"
225
+ payload = json.dumps({
226
+ "model": "Qwen/Qwen3-0.6B",
227
+ "input": queries,
228
+ }).encode()
229
+
230
+ req = urllib.request.Request(
231
+ url, data=payload,
232
+ headers={"Content-Type": "application/json"},
233
+ )
234
+ with urllib.request.urlopen(req) as resp:
235
+ result = json.loads(resp.read())
236
+
237
+ embeddings = [item["embedding"] for item in sorted(result["data"], key=lambda x: x["index"])]
238
+ return np.array(embeddings)
239
+
240
+ def _embed_local(self, queries: List[str]) -> np.ndarray:
241
+ """Embed by loading Qwen3-0.6B locally via vLLM."""
242
  if self._embedder is None:
243
  try:
244
  from vllm import LLM
245
  except ImportError:
246
  raise ImportError(
247
+ "vLLM is required for local embedding. "
248
+ "Install with: uv pip install vllm\n"
249
+ "Or start a vLLM server and pass embed_url to from_pretrained()."
250
  )
251
  self._embedder = LLM(
252
  model="Qwen/Qwen3-0.6B",
 
255
  dtype="half",
256
  )
257
 
 
 
 
258
  outputs = self._embedder.embed(queries)
259
  return np.array([o.outputs.embedding for o in outputs])
260