| """ |
| OpenAI-compatible HTTP server for DFlash speculative decoding on MLX. |
| |
| Supports: |
| - POST /v1/chat/completions (with streaming via SSE) |
| - POST /v1/completions |
| - GET /v1/models |
| - GET /health |
| - GET /metrics (DFlash-specific diagnostics) |
| |
| Inspired by bstnxbt/dflash-mlx server architecture and Aryagm's OpenAI server. |
| """ |

import json
import time
from typing import Any, Dict, List, Optional

from .speculative_decode import DFlashSpeculativeDecoder
from .adapters import load_target_model, LoadedTargetModel
from .convert import load_mlx_dflash


class DFlashServer:
    """OpenAI-compatible server wrapping a DFlashSpeculativeDecoder."""

    def __init__(
        self,
        target_model_path: str,
        draft_model_path: Optional[str] = None,
        block_size: int = 16,
        device: str = "metal",
    ):
        """Initialize the server with a target and an optional draft model.

        Args:
            target_model_path: Path or HF ID of the MLX target model.
            draft_model_path: Path or HF ID of a converted DFlash drafter.
            block_size: Draft block size.
            device: MLX device.
        """
        print(f"[Server] Loading target model: {target_model_path}...")
        self.loaded_target: LoadedTargetModel = load_target_model(target_model_path)

        if draft_model_path:
            print(f"[Server] Loading DFlash drafter: {draft_model_path}...")
            self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path)
        else:
            # No drafter given: report any drafter we can infer for this
            # target, but fall back to baseline generation (auto-loading the
            # inferred drafter is not implemented here).
            from .convert import _infer_target_model
            inferred = _infer_target_model(target_model_path)
            if inferred and inferred != "unknown":
                print(f"[Server] Auto-resolved drafter: {inferred} "
                      "(not auto-loaded; pass it explicitly to use DFlash)")
            else:
                print("[Server] No draft model; will use baseline generation")
            self.draft_model, self.draft_config = None, None

        if self.draft_model is not None:
            self.decoder = DFlashSpeculativeDecoder(
                target_model=self.loaded_target,
                draft_model=self.draft_model,
                tokenizer=self.loaded_target.tokenizer,
                block_size=block_size,
                device=device,
            )
            self.mode = "dflash"
        else:
            self.decoder = None
            self.mode = "baseline"

        # Rolling request metrics, exposed via /metrics.
        self.request_count = 0
        self.total_tokens = 0
        self.total_time = 0.0
        self.recent_requests: List[Dict] = []

    def health(self) -> Dict[str, Any]:
        return {
            "status": "ok",
            "mode": self.mode,
            "model": self.loaded_target.requested_model,
        }

    def models(self) -> Dict[str, Any]:
        return {
            "object": "list",
            "data": [{
                "id": self.loaded_target.requested_model,
                "object": "model",
                "owned_by": "dflash-mlx-universal",
            }],
        }

    def metrics(self) -> Dict[str, Any]:
        avg_tok_s = self.total_tokens / self.total_time if self.total_time > 0 else 0.0
        return {
            "request_count": self.request_count,
            "total_tokens": self.total_tokens,
            "avg_tokens_per_sec": avg_tok_s,
            "recent_requests": self.recent_requests[-32:],
            "mode": self.mode,
        }

    def _update_metrics(self, num_tokens: int, elapsed: float):
        self.request_count += 1
        self.total_tokens += num_tokens
        self.total_time += elapsed
        self.recent_requests.append({
            "timestamp": time.time(),
            "tokens": num_tokens,
            "time_sec": elapsed,
            "tok_s": num_tokens / elapsed if elapsed > 0 else 0.0,
        })
        # Keep only the most recent window.
        if len(self.recent_requests) > 32:
            self.recent_requests = self.recent_requests[-32:]

    def chat_completions(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 1024,
        temperature: float = 0.0,
        stream: bool = False,
        stop: Optional[List[str]] = None,
    ) -> Dict[str, Any] | Any:
        """Handle a chat completion request.

        Returns a response dict for non-streaming requests, or an SSE chunk
        generator when stream=True.
        """
        prompt = self._messages_to_prompt(messages)

        if stream:
            return self._stream_chat(prompt, max_tokens, temperature, stop)

        start = time.time()

        if self.mode == "dflash" and self.decoder is not None:
            output = self.decoder.generate(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stop_strings=stop,
            )
        else:
            # Baseline: plain autoregressive generation via mlx_lm. Note that
            # newer mlx_lm releases replace the `temp` keyword with a sampler
            # object; this uses the older keyword form.
            from mlx_lm.utils import generate as mlx_generate
            output = mlx_generate(
                model=self.loaded_target.model,
                tokenizer=self.loaded_target.tokenizer,
                prompt=prompt,
                max_tokens=max_tokens,
                temp=temperature,
            )

        elapsed = time.time() - start
        tokenizer = self.loaded_target.tokenizer
        prompt_tokens = len(tokenizer.encode(prompt))
        completion_tokens = len(tokenizer.encode(output))
        self._update_metrics(completion_tokens, elapsed)

        return {
            "id": f"chatcmpl-{int(time.time() * 1000)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": self.loaded_target.requested_model,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": output,
                },
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    def _stream_chat(self, prompt: str, max_tokens: int, temperature: float, stop):
        """Generator yielding OpenAI-style SSE chunks for one request."""
        completion_id = f"chatcmpl-{int(time.time() * 1000)}"
        created = int(time.time())
        start = time.time()

        def chunk_event(delta: Dict, finish_reason: Optional[str] = None) -> str:
            # All chunks of one stream share the same id and created timestamp,
            # matching the OpenAI chunk format.
            payload = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": self.loaded_target.requested_model,
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
            }
            return f"data: {json.dumps(payload)}\n\n"

        # Initial chunk announcing the assistant role.
        yield chunk_event({"role": "assistant"})

        accumulated = ""

        if self.mode == "dflash" and self.decoder is not None:
            for chunk in self.decoder.generate(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stop_strings=stop,
                stream=True,
            ):
                accumulated += chunk
                yield chunk_event({"content": chunk})
        else:
            # Baseline path has no incremental API here: generate the full
            # text, then emit word-level pseudo-chunks.
            from mlx_lm.utils import generate as mlx_generate
            output = mlx_generate(
                model=self.loaded_target.model,
                tokenizer=self.loaded_target.tokenizer,
                prompt=prompt,
                max_tokens=max_tokens,
                temp=temperature,
            )
            for i, word in enumerate(output.split(" ")):
                chunk = (" " if i else "") + word
                accumulated += chunk
                yield chunk_event({"content": chunk})

        yield chunk_event({}, finish_reason="stop")
        yield "data: [DONE]\n\n"

        # Streamed requests count toward /metrics once fully consumed.
        self._update_metrics(
            len(self.loaded_target.tokenizer.encode(accumulated)),
            time.time() - start,
        )

    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Convert OpenAI-style messages into a single prompt string."""
        tokenizer = self.loaded_target.tokenizer
        if hasattr(tokenizer, "apply_chat_template"):
            try:
                return tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except Exception:
                pass

        # Fallback: a simple role-tagged transcript ending with an open
        # assistant turn.
        prompt = ""
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                prompt += f"System: {content}\n"
            elif role == "user":
                prompt += f"User: {content}\n"
            elif role == "assistant":
                prompt += f"Assistant: {content}\n"
        prompt += "Assistant: "
        return prompt
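

# Programmatic usage sketch (the model path below is a placeholder, not a
# pinned artifact):
#
#   server = DFlashServer("mlx-community/Qwen2.5-7B-Instruct-4bit")
#   resp = server.chat_completions(
#       [{"role": "user", "content": "Hi"}], max_tokens=32,
#   )
#   print(resp["choices"][0]["message"]["content"])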


def create_app(
    target_model: str,
    draft_model: Optional[str] = None,
    block_size: int = 16,
    device: str = "metal",
):
    """Create the serving app.

    Returns a FastAPI app when FastAPI is installed, otherwise a stdlib
    BaseHTTPRequestHandler subclass for use with http.server.HTTPServer.
    """
    try:
        from fastapi import FastAPI, Request
        from fastapi.responses import StreamingResponse

        app = FastAPI(title="DFlash MLX Server")
        server = DFlashServer(target_model, draft_model, block_size, device)

        @app.get("/health")
        async def health():
            return server.health()

        @app.get("/v1/models")
        async def models():
            return server.models()

        @app.get("/metrics")
        async def metrics():
            return server.metrics()

        @app.post("/v1/chat/completions")
        async def chat_completions(request: Request):
            body = await request.json()
            result = server.chat_completions(
                messages=body.get("messages", []),
                max_tokens=body.get("max_tokens", 1024),
                temperature=body.get("temperature", 0.0),
                stream=body.get("stream", False),
                stop=body.get("stop", None),
            )
            if body.get("stream", False):
                return StreamingResponse(result, media_type="text/event-stream")
            return result

        return app

    except ImportError:
        print("[Server] FastAPI not installed. Install with: pip install fastapi uvicorn")

        # Minimal stdlib fallback (non-streaming only).
        from http.server import BaseHTTPRequestHandler

        class Handler(BaseHTTPRequestHandler):
            server_instance: Optional[DFlashServer] = None

            def do_GET(self):
                if self.path == "/health":
                    self._json_response(200, self.server_instance.health())
                elif self.path == "/v1/models":
                    self._json_response(200, self.server_instance.models())
                elif self.path == "/metrics":
                    self._json_response(200, self.server_instance.metrics())
                else:
                    self._json_response(404, {"error": "Not found"})

            def do_POST(self):
                if self.path == "/v1/chat/completions":
                    content_len = int(self.headers.get("Content-Length", 0))
                    body = json.loads(self.rfile.read(content_len))
                    result = self.server_instance.chat_completions(
                        messages=body.get("messages", []),
                        max_tokens=body.get("max_tokens", 1024),
                        temperature=body.get("temperature", 0.0),
                        stream=False,  # SSE streaming needs the FastAPI path
                        stop=body.get("stop", None),
                    )
                    self._json_response(200, result)
                else:
                    self._json_response(404, {"error": "Not found"})

            def _json_response(self, status: int, data: Dict):
                self.send_response(status)
                self.send_header("Content-Type", "application/json")
                self.end_headers()
                self.wfile.write(json.dumps(data).encode())

        Handler.server_instance = DFlashServer(target_model, draft_model, block_size, device)
        return Handler
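

# Serving the app programmatically (sketch; requires fastapi and uvicorn, and
# the model path is again a placeholder):
#
#   import uvicorn
#   app = create_app("mlx-community/Qwen2.5-7B-Instruct-4bit")
#   uvicorn.run(app, host="127.0.0.1", port=8000)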


def main():
    import argparse

    parser = argparse.ArgumentParser(description="DFlash MLX OpenAI-compatible server")
    parser.add_argument("--target", required=True, help="Target model path or HF ID")
    parser.add_argument("--draft", default=None, help="Draft model path or HF ID")
    parser.add_argument("--block-size", type=int, default=16)
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--device", default="metal")
    args = parser.parse_args()

    # create_app loads the model exactly once and returns either a FastAPI app
    # or a stdlib handler class, depending on what is installed.
    app = create_app(args.target, args.draft, args.block_size, args.device)

    from http.server import BaseHTTPRequestHandler, HTTPServer

    if isinstance(app, type) and issubclass(app, BaseHTTPRequestHandler):
        print(f"[Server] Starting simple HTTP on http://{args.host}:{args.port}")
        HTTPServer((args.host, args.port), app).serve_forever()
    else:
        try:
            import uvicorn
        except ImportError:
            raise SystemExit(
                "[Server] uvicorn is required to serve the FastAPI app: "
                "pip install uvicorn"
            )
        print(f"[Server] Starting FastAPI on http://{args.host}:{args.port}")
        uvicorn.run(app, host=args.host, port=args.port)
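

# CLI example (the package name and model path are placeholders):
#
#   python -m dflash_mlx.server --target mlx-community/Qwen2.5-7B-Instruct-4bit \
#       --draft ./drafter-mlx --block-size 16 --port 8000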


if __name__ == "__main__":
    main()