tritesh committed (verified) · Commit 834cedc · Parent(s): 5fdf20a

Upload dflash_mlx/serve.py

Files changed (1):
  1. dflash_mlx/serve.py (added, +419 lines)

dflash_mlx/serve.py
"""
OpenAI-compatible HTTP server for DFlash speculative decoding on MLX.

Supports:
- POST /v1/chat/completions (with streaming via SSE)
- POST /v1/completions
- GET /v1/models
- GET /health
- GET /metrics (DFlash-specific diagnostics)

Inspired by bstnxbt/dflash-mlx server architecture and Aryagm's OpenAI server.
"""
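# Example usage (a sketch added for illustration, not part of the original file):
# launch the server through the CLI defined in main() below, then query the
# OpenAI-compatible endpoint with curl. The model paths are placeholders, and
# `python -m dflash_mlx.serve` assumes the package layout implied by the
# relative imports below.
#
#   python -m dflash_mlx.serve --target <mlx-target-model> --draft <dflash-drafter> --port 8000
#
#   curl http://127.0.0.1:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64}'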
13
+
14
+ import json
15
+ import time
16
+ from typing import Any, Dict, List, Optional
17
+
18
+ from .speculative_decode import DFlashSpeculativeDecoder
19
+ from .adapters import load_target_model, LoadedTargetModel
20
+ from .convert import load_mlx_dflash
21
+
22
+
23
+ class DFlashServer:
24
+ """OpenAI-compatible server wrapping a DFlashSpeculativeDecoder."""
25
+
26
+ def __init__(
27
+ self,
28
+ target_model_path: str,
29
+ draft_model_path: Optional[str] = None,
30
+ block_size: int = 16,
31
+ device: str = "metal",
32
+ ):
33
+ """Initialize server with target and optional draft model.
34
+
35
+ Args:
36
+ target_model_path: Path or HF ID of MLX target model
37
+ draft_model_path: Path or HF ID of converted DFlash drafter
38
+ block_size: Draft block size
39
+ device: MLX device
40
+ """
41
+ print(f"[Server] Loading target model: {target_model_path}...")
42
+ self.loaded_target = load_target_model(target_model_path)
43
+
44
+ if draft_model_path:
45
+ print(f"[Server] Loading DFlash drafter: {draft_model_path}...")
46
+ self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path)
47
+ else:
48
+ # Try to auto-resolve draft model
49
+ from .convert import _infer_target_model
50
+ inferred = _infer_target_model(target_model_path)
51
+ if inferred and inferred != "unknown":
52
+ print(f"[Server] Auto-resolved drafter: {inferred}")
53
+ # Look up in registry...
54
+ self.draft_model, self.draft_config = None, None
55
+ else:
56
+ print("[Server] No draft model — will use baseline generation")
57
+ self.draft_model, self.draft_config = None, None
58
+
59
+ if self.draft_model is not None:
60
+ self.decoder = DFlashSpeculativeDecoder(
61
+ target_model=self.loaded_target,
62
+ draft_model=self.draft_model,
63
+ tokenizer=self.loaded_target.tokenizer,
64
+ block_size=block_size,
65
+ device=device,
66
+ )
67
+ self.mode = "dflash"
68
+ else:
69
+ self.decoder = None
70
+ self.mode = "baseline"
71
+
72
+ # Metrics
73
+ self.request_count = 0
74
+ self.total_tokens = 0
75
+ self.total_time = 0.0
76
+ self.recent_requests: List[Dict] = []
77
+
78
+ def health(self) -> Dict[str, Any]:
79
+ return {"status": "ok", "mode": self.mode, "model": self.loaded_target.requested_model}
80
+
81
+ def models(self) -> Dict[str, Any]:
82
+ return {
83
+ "object": "list",
84
+ "data": [{
85
+ "id": self.loaded_target.requested_model,
86
+ "object": "model",
87
+ "owned_by": "dflash-mlx-universal",
88
+ }]
89
+ }
90
+
91
+ def metrics(self) -> Dict[str, Any]:
92
+ avg_tok_s = self.total_tokens / self.total_time if self.total_time > 0 else 0
93
+ return {
94
+ "request_count": self.request_count,
95
+ "total_tokens": self.total_tokens,
96
+ "avg_tokens_per_sec": avg_tok_s,
97
+ "recent_requests": self.recent_requests[-32:],
98
+ "mode": self.mode,
99
+ }
100
+
101
+ def _update_metrics(self, num_tokens: int, elapsed: float):
102
+ self.request_count += 1
103
+ self.total_tokens += num_tokens
104
+ self.total_time += elapsed
105
+ self.recent_requests.append({
106
+ "timestamp": time.time(),
107
+ "tokens": num_tokens,
108
+ "time_sec": elapsed,
109
+ "tok_s": num_tokens / elapsed if elapsed > 0 else 0,
110
+ })
111
+ if len(self.recent_requests) > 32:
112
+ self.recent_requests = self.recent_requests[-32:]
113
+
114
+ def chat_completions(
115
+ self,
116
+ messages: List[Dict[str, str]],
117
+ max_tokens: int = 1024,
118
+ temperature: float = 0.0,
119
+ stream: bool = False,
120
+ stop: Optional[List[str]] = None,
121
+ ) -> Dict[str, Any] | Any:
122
+ """Handle chat completion request.
123
+
124
+ Returns dict for non-streaming, generator for streaming.
125
+ """
126
+ # Build prompt from messages
127
+ prompt = self._messages_to_prompt(messages)
128
+
129
+ if stream:
130
+ return self._stream_chat(prompt, max_tokens, temperature, stop)
131
+
132
+ # Non-streaming
133
+ start = time.time()
134
+
135
+ if self.mode == "dflash" and self.decoder is not None:
136
+ output = self.decoder.generate(
137
+ prompt=prompt,
138
+ max_tokens=max_tokens,
139
+ temperature=temperature,
140
+ stop_strings=stop,
141
+ )
142
+ else:
143
+ # Baseline mlx_lm generation
144
+ from mlx_lm.utils import generate as mlx_generate
145
+ output = mlx_generate(
146
+ model=self.loaded_target.model,
147
+ tokenizer=self.loaded_target.tokenizer,
148
+ prompt=prompt,
149
+ max_tokens=max_tokens,
150
+ temp=temperature,
151
+ )
152
+
153
+ elapsed = time.time() - start
154
+ num_tokens = len(self.loaded_target.tokenizer.encode(output))
155
+ self._update_metrics(num_tokens, elapsed)
156
+
157
+ return {
158
+ "id": f"chatcmpl-{int(time.time()*1000)}",
159
+ "object": "chat.completion",
160
+ "created": int(time.time()),
161
+ "model": self.loaded_target.requested_model,
162
+ "choices": [{
163
+ "index": 0,
164
+ "message": {
165
+ "role": "assistant",
166
+ "content": output,
167
+ },
168
+ "finish_reason": "stop",
169
+ }],
170
+ "usage": {
171
+ "prompt_tokens": len(self.loaded_target.tokenizer.encode(prompt)),
172
+ "completion_tokens": num_tokens,
173
+ "total_tokens": len(self.loaded_target.tokenizer.encode(prompt)) + num_tokens,
174
+ }
175
+ }
176
+
    def _stream_chat(self, prompt: str, max_tokens: int, temperature: float, stop):
        """Generator for streaming SSE chunks."""

        def event(data: Dict) -> str:
            return f"data: {json.dumps(data)}\n\n"

        # Use a single id/timestamp for every chunk of this completion, as OpenAI
        # streaming clients expect all chunks of one response to share an id.
        chunk_id = f"chatcmpl-{int(time.time()*1000)}"
        created = int(time.time())

        # Yield initial role
        yield event({
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": self.loaded_target.requested_model,
            "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
        })

        accumulated = ""

        if self.mode == "dflash" and self.decoder is not None:
            # Use streaming generate
            for chunk in self.decoder.generate(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stop_strings=stop,
                stream=True,
            ):
                accumulated += chunk
                yield event({
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": self.loaded_target.requested_model,
                    "choices": [{"index": 0, "delta": {"content": chunk}, "finish_reason": None}],
                })
        else:
            # Baseline: generate then stream word-by-word (not true streaming)
            from mlx_lm.utils import generate as mlx_generate
            output = mlx_generate(
                model=self.loaded_target.model,
                tokenizer=self.loaded_target.tokenizer,
                prompt=prompt,
                max_tokens=max_tokens,
                temp=temperature,
            )
            for word in output.split(" "):
                chunk = word + " "
                accumulated += chunk
                yield event({
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": self.loaded_target.requested_model,
                    "choices": [{"index": 0, "delta": {"content": chunk}, "finish_reason": None}],
                })

        # Final chunk
        yield event({
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": self.loaded_target.requested_model,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        })
        yield "data: [DONE]\n\n"
    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Convert OpenAI messages format to prompt string."""
        # Try chat template
        tokenizer = self.loaded_target.tokenizer
        if hasattr(tokenizer, "apply_chat_template"):
            try:
                return tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except Exception:
                pass

        # Fallback: simple concatenation
        prompt = ""
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                prompt += f"System: {content}\n"
            elif role == "user":
                prompt += f"User: {content}\n"
            elif role == "assistant":
                prompt += f"Assistant: {content}\n"
        prompt += "Assistant: "
        return prompt

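# Programmatic use of DFlashServer without going through HTTP (a sketch added
# for illustration; the model paths below are placeholders, not files shipped
# with this repo):
#
#   server = DFlashServer("path/to/mlx-target-model", "path/to/dflash-drafter")
#   reply = server.chat_completions(
#       [{"role": "user", "content": "Hi"}], max_tokens=32, temperature=0.0
#   )
#   print(reply["choices"][0]["message"]["content"])
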
def create_app(target_model: str, draft_model: Optional[str] = None, block_size: int = 16):
    """Create a Flask/FastAPI-style app for serving."""
    try:
        from fastapi import FastAPI, Request
        from fastapi.responses import StreamingResponse

        app = FastAPI(title="DFlash MLX Server")
        server = DFlashServer(target_model, draft_model, block_size)

        @app.get("/health")
        async def health():
            return server.health()

        @app.get("/v1/models")
        async def models():
            return server.models()

        @app.get("/metrics")
        async def metrics():
            return server.metrics()

        @app.post("/v1/chat/completions")
        async def chat_completions(request: Request):
            body = await request.json()
            messages = body.get("messages", [])
            max_tokens = body.get("max_tokens", 1024)
            temperature = body.get("temperature", 0.0)
            stream = body.get("stream", False)
            stop = body.get("stop", None)

            result = server.chat_completions(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=stream,
                stop=stop,
            )

            if stream:
                return StreamingResponse(result, media_type="text/event-stream")
            return result

        return app

    except ImportError:
        print("[Server] FastAPI not installed. Install with: pip install fastapi uvicorn")

        # Fallback: simple HTTP server
        from http.server import BaseHTTPRequestHandler, HTTPServer

        class Handler(BaseHTTPRequestHandler):
            server_instance = None

            def do_GET(self):
                if self.path == "/health":
                    self._json_response(200, self.server_instance.health())
                elif self.path == "/v1/models":
                    self._json_response(200, self.server_instance.models())
                elif self.path == "/metrics":
                    self._json_response(200, self.server_instance.metrics())
                else:
                    self._json_response(404, {"error": "Not found"})

            def do_POST(self):
                if self.path == "/v1/chat/completions":
                    content_len = int(self.headers.get("Content-Length", 0))
                    body = json.loads(self.rfile.read(content_len))
                    # The fallback server does not support SSE streaming; the
                    # request's "stream" flag is ignored and a full response is
                    # returned in one shot.
                    result = self.server_instance.chat_completions(
                        messages=body.get("messages", []),
                        max_tokens=body.get("max_tokens", 1024),
                        temperature=body.get("temperature", 0.0),
                        stream=False,
                        stop=body.get("stop", None),
                    )
                    self._json_response(200, result)
                else:
                    self._json_response(404, {"error": "Not found"})

            def _json_response(self, status: int, data: Dict):
                self.send_response(status)
                self.send_header("Content-Type", "application/json")
                self.end_headers()
                self.wfile.write(json.dumps(data).encode())

        Handler.server_instance = DFlashServer(target_model, draft_model, block_size)
        return Handler

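# Note (added for illustration): when FastAPI is installed, create_app() returns a
# FastAPI application that can be served directly with uvicorn; when it is not,
# create_app() instead returns an http.server request-handler class, which is what
# the fallback branch of main() below expects. A sketch with placeholder paths:
#
#   import uvicorn
#   uvicorn.run(create_app("path/to/mlx-target-model", "path/to/dflash-drafter"),
#               host="127.0.0.1", port=8000)
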
def main():
    import argparse
    parser = argparse.ArgumentParser(description="DFlash MLX OpenAI-compatible server")
    parser.add_argument("--target", required=True, help="Target model path or HF ID")
    parser.add_argument("--draft", default=None, help="Draft model path or HF ID")
    parser.add_argument("--block-size", type=int, default=16)
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--device", default="metal")
    args = parser.parse_args()

    try:
        import uvicorn
        from fastapi import FastAPI, Request
        from fastapi.responses import StreamingResponse

        # Load the models only after the imports succeed; the ImportError fallback
        # below builds its own DFlashServer inside create_app, so constructing one
        # up front as well would load the models twice.
        server = DFlashServer(args.target, args.draft, args.block_size, args.device)

        app = FastAPI()

        @app.get("/health")
        async def health():
            return server.health()

        @app.get("/v1/models")
        async def models():
            return server.models()

        @app.get("/metrics")
        async def metrics():
            return server.metrics()

        @app.post("/v1/chat/completions")
        async def chat_completions(request: Request):
            body = await request.json()
            result = server.chat_completions(
                messages=body.get("messages", []),
                max_tokens=body.get("max_tokens", 1024),
                temperature=body.get("temperature", 0.0),
                stream=body.get("stream", False),
                stop=body.get("stop", None),
            )
            if body.get("stream", False):
                return StreamingResponse(result, media_type="text/event-stream")
            return result

        print(f"[Server] Starting FastAPI on http://{args.host}:{args.port}")
        uvicorn.run(app, host=args.host, port=args.port)

    except ImportError:
        print("[Server] FastAPI/uvicorn not available, using simple HTTP server")
        from http.server import HTTPServer
        handler = create_app(args.target, args.draft, args.block_size)
        httpd = HTTPServer((args.host, args.port), handler)
        print(f"[Server] Starting simple HTTP on http://{args.host}:{args.port}")
        httpd.serve_forever()


if __name__ == "__main__":
    main()
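For reference, a minimal client sketch (not part of the commit): it assumes the server above is running with the default host/port from main() and uses only the Python standard library.

    # Query a locally running DFlash MLX server and print the assistant's reply.
    import json
    import urllib.request

    payload = {
        "messages": [{"role": "user", "content": "Explain speculative decoding in one sentence."}],
        "max_tokens": 64,
        "temperature": 0.0,
        "stream": False,
    }
    req = urllib.request.Request(
        "http://127.0.0.1:8000/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        body = json.loads(resp.read())

    print(body["choices"][0]["message"]["content"])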