""" OpenAI-compatible HTTP server for DFlash speculative decoding on MLX. Supports: - POST /v1/chat/completions (with streaming via SSE) - POST /v1/completions - GET /v1/models - GET /health - GET /metrics (DFlash-specific diagnostics) Inspired by bstnxbt/dflash-mlx server architecture and Aryagm's OpenAI server. """ import json import time from typing import Any, Dict, List, Optional from .speculative_decode import DFlashSpeculativeDecoder from .adapters import load_target_model, LoadedTargetModel from .convert import load_mlx_dflash class DFlashServer: """OpenAI-compatible server wrapping a DFlashSpeculativeDecoder.""" def __init__( self, target_model_path: str, draft_model_path: Optional[str] = None, block_size: int = 16, device: str = "metal", ): """Initialize server with target and optional draft model. Args: target_model_path: Path or HF ID of MLX target model draft_model_path: Path or HF ID of converted DFlash drafter block_size: Draft block size device: MLX device """ print(f"[Server] Loading target model: {target_model_path}...") self.loaded_target = load_target_model(target_model_path) if draft_model_path: print(f"[Server] Loading DFlash drafter: {draft_model_path}...") self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path) else: # Try to auto-resolve draft model from .convert import _infer_target_model inferred = _infer_target_model(target_model_path) if inferred and inferred != "unknown": print(f"[Server] Auto-resolved drafter: {inferred}") # Look up in registry... self.draft_model, self.draft_config = None, None else: print("[Server] No draft model — will use baseline generation") self.draft_model, self.draft_config = None, None if self.draft_model is not None: self.decoder = DFlashSpeculativeDecoder( target_model=self.loaded_target, draft_model=self.draft_model, tokenizer=self.loaded_target.tokenizer, block_size=block_size, device=device, ) self.mode = "dflash" else: self.decoder = None self.mode = "baseline" # Metrics self.request_count = 0 self.total_tokens = 0 self.total_time = 0.0 self.recent_requests: List[Dict] = [] def health(self) -> Dict[str, Any]: return {"status": "ok", "mode": self.mode, "model": self.loaded_target.requested_model} def models(self) -> Dict[str, Any]: return { "object": "list", "data": [{ "id": self.loaded_target.requested_model, "object": "model", "owned_by": "dflash-mlx-universal", }] } def metrics(self) -> Dict[str, Any]: avg_tok_s = self.total_tokens / self.total_time if self.total_time > 0 else 0 return { "request_count": self.request_count, "total_tokens": self.total_tokens, "avg_tokens_per_sec": avg_tok_s, "recent_requests": self.recent_requests[-32:], "mode": self.mode, } def _update_metrics(self, num_tokens: int, elapsed: float): self.request_count += 1 self.total_tokens += num_tokens self.total_time += elapsed self.recent_requests.append({ "timestamp": time.time(), "tokens": num_tokens, "time_sec": elapsed, "tok_s": num_tokens / elapsed if elapsed > 0 else 0, }) if len(self.recent_requests) > 32: self.recent_requests = self.recent_requests[-32:] def chat_completions( self, messages: List[Dict[str, str]], max_tokens: int = 1024, temperature: float = 0.0, stream: bool = False, stop: Optional[List[str]] = None, ) -> Dict[str, Any] | Any: """Handle chat completion request. Returns dict for non-streaming, generator for streaming. 
""" # Build prompt from messages prompt = self._messages_to_prompt(messages) if stream: return self._stream_chat(prompt, max_tokens, temperature, stop) # Non-streaming start = time.time() if self.mode == "dflash" and self.decoder is not None: output = self.decoder.generate( prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop_strings=stop, ) else: # Baseline mlx_lm generation from mlx_lm.utils import generate as mlx_generate output = mlx_generate( model=self.loaded_target.model, tokenizer=self.loaded_target.tokenizer, prompt=prompt, max_tokens=max_tokens, temp=temperature, ) elapsed = time.time() - start num_tokens = len(self.loaded_target.tokenizer.encode(output)) self._update_metrics(num_tokens, elapsed) return { "id": f"chatcmpl-{int(time.time()*1000)}", "object": "chat.completion", "created": int(time.time()), "model": self.loaded_target.requested_model, "choices": [{ "index": 0, "message": { "role": "assistant", "content": output, }, "finish_reason": "stop", }], "usage": { "prompt_tokens": len(self.loaded_target.tokenizer.encode(prompt)), "completion_tokens": num_tokens, "total_tokens": len(self.loaded_target.tokenizer.encode(prompt)) + num_tokens, } } def _stream_chat(self, prompt: str, max_tokens: int, temperature: float, stop): """Generator for streaming SSE chunks.""" def event(data: Dict) -> str: return f"data: {json.dumps(data)}\n\n" # Yield initial role yield event({ "id": f"chatcmpl-{int(time.time()*1000)}", "object": "chat.completion.chunk", "created": int(time.time()), "model": self.loaded_target.requested_model, "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], }) accumulated = "" if self.mode == "dflash" and self.decoder is not None: # Use streaming generate for chunk in self.decoder.generate( prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop_strings=stop, stream=True, ): accumulated += chunk yield event({ "id": f"chatcmpl-{int(time.time()*1000)}", "object": "chat.completion.chunk", "created": int(time.time()), "model": self.loaded_target.requested_model, "choices": [{"index": 0, "delta": {"content": chunk}, "finish_reason": None}], }) else: # Baseline: generate then stream word-by-word (not true streaming) from mlx_lm.utils import generate as mlx_generate output = mlx_generate( model=self.loaded_target.model, tokenizer=self.loaded_target.tokenizer, prompt=prompt, max_tokens=max_tokens, temp=temperature, ) for word in output.split(" "): chunk = word + " " accumulated += chunk yield event({ "id": f"chatcmpl-{int(time.time()*1000)}", "object": "chat.completion.chunk", "created": int(time.time()), "model": self.loaded_target.requested_model, "choices": [{"index": 0, "delta": {"content": chunk}, "finish_reason": None}], }) # Final chunk yield event({ "id": f"chatcmpl-{int(time.time()*1000)}", "object": "chat.completion.chunk", "created": int(time.time()), "model": self.loaded_target.requested_model, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], }) yield "data: [DONE]\n\n" def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: """Convert OpenAI messages format to prompt string.""" # Try chat template tokenizer = self.loaded_target.tokenizer if hasattr(tokenizer, "apply_chat_template"): try: return tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) except Exception: pass # Fallback: simple concatenation prompt = "" for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") if role == "system": prompt += f"System: 
{content}\n" elif role == "user": prompt += f"User: {content}\n" elif role == "assistant": prompt += f"Assistant: {content}\n" prompt += "Assistant: " return prompt def create_app(target_model: str, draft_model: Optional[str] = None, block_size: int = 16): """Create a Flask/FastAPI-style app for serving.""" try: from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse app = FastAPI(title="DFlash MLX Server") server = DFlashServer(target_model, draft_model, block_size) @app.get("/health") async def health(): return server.health() @app.get("/v1/models") async def models(): return server.models() @app.get("/metrics") async def metrics(): return server.metrics() @app.post("/v1/chat/completions") async def chat_completions(request: Request): body = await request.json() messages = body.get("messages", []) max_tokens = body.get("max_tokens", 1024) temperature = body.get("temperature", 0.0) stream = body.get("stream", False) stop = body.get("stop", None) result = server.chat_completions( messages=messages, max_tokens=max_tokens, temperature=temperature, stream=stream, stop=stop, ) if stream: return StreamingResponse(result, media_type="text/event-stream") return result return app except ImportError: print("[Server] FastAPI not installed. Install with: pip install fastapi uvicorn") # Fallback: simple HTTP server from http.server import BaseHTTPRequestHandler, HTTPServer import threading class Handler(BaseHTTPRequestHandler): server_instance = None def do_GET(self): if self.path == "/health": self._json_response(200, self.server_instance.health()) elif self.path == "/v1/models": self._json_response(200, self.server_instance.models()) elif self.path == "/metrics": self._json_response(200, self.server_instance.metrics()) else: self._json_response(404, {"error": "Not found"}) def do_POST(self): if self.path == "/v1/chat/completions": content_len = int(self.headers.get("Content-Length", 0)) body = json.loads(self.rfile.read(content_len)) result = self.server_instance.chat_completions( messages=body.get("messages", []), max_tokens=body.get("max_tokens", 1024), temperature=body.get("temperature", 0.0), stream=False, stop=body.get("stop", None), ) self._json_response(200, result) else: self._json_response(404, {"error": "Not found"}) def _json_response(self, status: int, data: Dict): self.send_response(status) self.send_header("Content-Type", "application/json") self.end_headers() self.wfile.write(json.dumps(data).encode()) Handler.server_instance = DFlashServer(target_model, draft_model, block_size) return Handler def main(): import argparse parser = argparse.ArgumentParser(description="DFlash MLX OpenAI-compatible server") parser.add_argument("--target", required=True, help="Target model path or HF ID") parser.add_argument("--draft", default=None, help="Draft model path or HF ID") parser.add_argument("--block-size", type=int, default=16) parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--device", default="metal") args = parser.parse_args() server = DFlashServer(args.target, args.draft, args.block_size, args.device) try: import uvicorn from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse app = FastAPI() @app.get("/health") async def health(): return server.health() @app.get("/v1/models") async def models(): return server.models() @app.get("/metrics") async def metrics(): return server.metrics() @app.post("/v1/chat/completions") async def chat_completions(request: 

def main():
    import argparse

    parser = argparse.ArgumentParser(description="DFlash MLX OpenAI-compatible server")
    parser.add_argument("--target", required=True, help="Target model path or HF ID")
    parser.add_argument("--draft", default=None, help="Draft model path or HF ID")
    parser.add_argument("--block-size", type=int, default=16)
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--device", default="metal")
    args = parser.parse_args()

    try:
        import uvicorn
        from fastapi import FastAPI, Request
        from fastapi.responses import StreamingResponse
    except ImportError:
        # Fall back to the stdlib HTTP server; create_app() loads the model itself,
        # so avoid constructing a DFlashServer twice on this path.
        print("[Server] FastAPI/uvicorn not available, using simple HTTP server")
        from http.server import HTTPServer

        handler = create_app(args.target, args.draft, args.block_size)
        httpd = HTTPServer((args.host, args.port), handler)
        print(f"[Server] Starting simple HTTP on http://{args.host}:{args.port}")
        httpd.serve_forever()
        return

    # Load the model once, then expose it through FastAPI
    server = DFlashServer(args.target, args.draft, args.block_size, args.device)
    app = FastAPI()

    @app.get("/health")
    async def health():
        return server.health()

    @app.get("/v1/models")
    async def models():
        return server.models()

    @app.get("/metrics")
    async def metrics():
        return server.metrics()

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request):
        body = await request.json()
        result = server.chat_completions(
            messages=body.get("messages", []),
            max_tokens=body.get("max_tokens", 1024),
            temperature=body.get("temperature", 0.0),
            stream=body.get("stream", False),
            stop=body.get("stop", None),
        )
        if body.get("stream", False):
            return StreamingResponse(result, media_type="text/event-stream")
        return result

    print(f"[Server] Starting FastAPI on http://{args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()
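
# Example (sketch): serving the FastAPI app returned by create_app() with uvicorn,
# as an alternative to the CLI above. Assumes FastAPI/uvicorn are installed
# (otherwise create_app() returns the stdlib handler class instead); the model ID
# is a placeholder.
#
#   import uvicorn
#   app = create_app("mlx-community/SomeTargetModel")
#   uvicorn.run(app, host="127.0.0.1", port=8000)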