"""
OpenAI-compatible HTTP server for DFlash speculative decoding on MLX.
Supports:
- POST /v1/chat/completions (with streaming via SSE)
- POST /v1/completions (not yet wired into the routes in this module)
- GET /v1/models
- GET /health
- GET /metrics (DFlash-specific diagnostics)
Inspired by bstnxbt/dflash-mlx server architecture and Aryagm's OpenAI server.
"""
import json
import time
from typing import Any, Dict, List, Optional
from .speculative_decode import DFlashSpeculativeDecoder
from .adapters import load_target_model, LoadedTargetModel
from .convert import load_mlx_dflash
class DFlashServer:
"""OpenAI-compatible server wrapping a DFlashSpeculativeDecoder."""
def __init__(
self,
target_model_path: str,
draft_model_path: Optional[str] = None,
block_size: int = 16,
device: str = "metal",
):
"""Initialize server with target and optional draft model.
Args:
target_model_path: Path or HF ID of MLX target model
draft_model_path: Path or HF ID of converted DFlash drafter
block_size: Draft block size
device: MLX device
"""
print(f"[Server] Loading target model: {target_model_path}...")
self.loaded_target = load_target_model(target_model_path)
if draft_model_path:
print(f"[Server] Loading DFlash drafter: {draft_model_path}...")
self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path)
else:
            # Try to auto-resolve a drafter from the target model name
            from .convert import _infer_target_model
            inferred = _infer_target_model(target_model_path)
            if inferred and inferred != "unknown":
                # Drafter registry lookup is not implemented yet, so no drafter is loaded
                print(f"[Server] Auto-resolve found '{inferred}', but drafter registry lookup is not implemented; using baseline generation")
                self.draft_model, self.draft_config = None, None
            else:
                print("[Server] No draft model — will use baseline generation")
                self.draft_model, self.draft_config = None, None
if self.draft_model is not None:
self.decoder = DFlashSpeculativeDecoder(
target_model=self.loaded_target,
draft_model=self.draft_model,
tokenizer=self.loaded_target.tokenizer,
block_size=block_size,
device=device,
)
self.mode = "dflash"
else:
self.decoder = None
self.mode = "baseline"
# Metrics
self.request_count = 0
self.total_tokens = 0
self.total_time = 0.0
self.recent_requests: List[Dict] = []
def health(self) -> Dict[str, Any]:
return {"status": "ok", "mode": self.mode, "model": self.loaded_target.requested_model}
def models(self) -> Dict[str, Any]:
return {
"object": "list",
"data": [{
"id": self.loaded_target.requested_model,
"object": "model",
"owned_by": "dflash-mlx-universal",
}]
}
def metrics(self) -> Dict[str, Any]:
avg_tok_s = self.total_tokens / self.total_time if self.total_time > 0 else 0
return {
"request_count": self.request_count,
"total_tokens": self.total_tokens,
"avg_tokens_per_sec": avg_tok_s,
"recent_requests": self.recent_requests[-32:],
"mode": self.mode,
}
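    # Example /metrics payload (values are illustrative):
    #   {"request_count": 12, "total_tokens": 5480, "avg_tokens_per_sec": 41.3,
    #    "recent_requests": [{"timestamp": ..., "tokens": 128, "time_sec": 3.1, "tok_s": 41.3}, ...],
    #    "mode": "dflash"}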
def _update_metrics(self, num_tokens: int, elapsed: float):
self.request_count += 1
self.total_tokens += num_tokens
self.total_time += elapsed
self.recent_requests.append({
"timestamp": time.time(),
"tokens": num_tokens,
"time_sec": elapsed,
"tok_s": num_tokens / elapsed if elapsed > 0 else 0,
})
if len(self.recent_requests) > 32:
self.recent_requests = self.recent_requests[-32:]
def chat_completions(
self,
messages: List[Dict[str, str]],
max_tokens: int = 1024,
temperature: float = 0.0,
stream: bool = False,
stop: Optional[List[str]] = None,
) -> Dict[str, Any] | Any:
"""Handle chat completion request.
Returns dict for non-streaming, generator for streaming.
"""
# Build prompt from messages
prompt = self._messages_to_prompt(messages)
if stream:
return self._stream_chat(prompt, max_tokens, temperature, stop)
# Non-streaming
start = time.time()
if self.mode == "dflash" and self.decoder is not None:
output = self.decoder.generate(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
stop_strings=stop,
)
else:
            # Baseline mlx_lm generation (note: the `temp` kwarg matches older
            # mlx_lm releases; newer versions expect a `sampler` argument instead)
            from mlx_lm.utils import generate as mlx_generate
output = mlx_generate(
model=self.loaded_target.model,
tokenizer=self.loaded_target.tokenizer,
prompt=prompt,
max_tokens=max_tokens,
temp=temperature,
)
elapsed = time.time() - start
        num_tokens = len(self.loaded_target.tokenizer.encode(output))
        prompt_tokens = len(self.loaded_target.tokenizer.encode(prompt))
self._update_metrics(num_tokens, elapsed)
return {
"id": f"chatcmpl-{int(time.time()*1000)}",
"object": "chat.completion",
"created": int(time.time()),
"model": self.loaded_target.requested_model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": output,
},
"finish_reason": "stop",
}],
"usage": {
"prompt_tokens": len(self.loaded_target.tokenizer.encode(prompt)),
"completion_tokens": num_tokens,
"total_tokens": len(self.loaded_target.tokenizer.encode(prompt)) + num_tokens,
}
}
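    # Usage sketch (messages follow the OpenAI chat schema):
    #   resp = server.chat_completions([{"role": "user", "content": "Hi"}], max_tokens=64)
    #   resp["choices"][0]["message"]["content"]
    #   # With stream=True the call returns a generator of pre-formatted SSE strings:
    #   for sse in server.chat_completions([{"role": "user", "content": "Hi"}], stream=True):
    #       ...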
def _stream_chat(self, prompt: str, max_tokens: int, temperature: float, stop):
"""Generator for streaming SSE chunks."""
def event(data: Dict) -> str:
return f"data: {json.dumps(data)}\n\n"
# Yield initial role
yield event({
"id": f"chatcmpl-{int(time.time()*1000)}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.loaded_target.requested_model,
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
})
accumulated = ""
if self.mode == "dflash" and self.decoder is not None:
# Use streaming generate
for chunk in self.decoder.generate(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
stop_strings=stop,
stream=True,
):
accumulated += chunk
yield event({
"id": f"chatcmpl-{int(time.time()*1000)}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.loaded_target.requested_model,
"choices": [{"index": 0, "delta": {"content": chunk}, "finish_reason": None}],
})
else:
# Baseline: generate then stream word-by-word (not true streaming)
from mlx_lm.utils import generate as mlx_generate
output = mlx_generate(
model=self.loaded_target.model,
tokenizer=self.loaded_target.tokenizer,
prompt=prompt,
max_tokens=max_tokens,
temp=temperature,
)
for word in output.split(" "):
chunk = word + " "
accumulated += chunk
yield event({
"id": f"chatcmpl-{int(time.time()*1000)}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.loaded_target.requested_model,
"choices": [{"index": 0, "delta": {"content": chunk}, "finish_reason": None}],
})
        # Record metrics for streamed requests as well, then emit the final chunk
        self._update_metrics(len(self.loaded_target.tokenizer.encode(accumulated)), time.time() - start)
        # Final chunk
yield event({
"id": f"chatcmpl-{int(time.time()*1000)}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.loaded_target.requested_model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
})
yield "data: [DONE]\n\n"
def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
"""Convert OpenAI messages format to prompt string."""
# Try chat template
tokenizer = self.loaded_target.tokenizer
if hasattr(tokenizer, "apply_chat_template"):
try:
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
except Exception:
pass
# Fallback: simple concatenation
prompt = ""
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "system":
prompt += f"System: {content}\n"
elif role == "user":
prompt += f"User: {content}\n"
elif role == "assistant":
prompt += f"Assistant: {content}\n"
prompt += "Assistant: "
return prompt
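    # Fallback prompt for [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]:
    #   "System: Be brief.\nUser: Hi\nAssistant: "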
def create_app(target_model: str, draft_model: Optional[str] = None, block_size: int = 16):
"""Create a Flask/FastAPI-style app for serving."""
try:
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
app = FastAPI(title="DFlash MLX Server")
server = DFlashServer(target_model, draft_model, block_size)
@app.get("/health")
async def health():
return server.health()
@app.get("/v1/models")
async def models():
return server.models()
@app.get("/metrics")
async def metrics():
return server.metrics()
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
body = await request.json()
messages = body.get("messages", [])
max_tokens = body.get("max_tokens", 1024)
temperature = body.get("temperature", 0.0)
stream = body.get("stream", False)
stop = body.get("stop", None)
result = server.chat_completions(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
stream=stream,
stop=stop,
)
if stream:
return StreamingResponse(result, media_type="text/event-stream")
return result
return app
except ImportError:
print("[Server] FastAPI not installed. Install with: pip install fastapi uvicorn")
# Fallback: simple HTTP server
from http.server import BaseHTTPRequestHandler, HTTPServer
class Handler(BaseHTTPRequestHandler):
server_instance = None
def do_GET(self):
if self.path == "/health":
self._json_response(200, self.server_instance.health())
elif self.path == "/v1/models":
self._json_response(200, self.server_instance.models())
elif self.path == "/metrics":
self._json_response(200, self.server_instance.metrics())
else:
self._json_response(404, {"error": "Not found"})
def do_POST(self):
if self.path == "/v1/chat/completions":
content_len = int(self.headers.get("Content-Length", 0))
body = json.loads(self.rfile.read(content_len))
result = self.server_instance.chat_completions(
messages=body.get("messages", []),
max_tokens=body.get("max_tokens", 1024),
temperature=body.get("temperature", 0.0),
stream=False,
stop=body.get("stop", None),
)
self._json_response(200, result)
else:
self._json_response(404, {"error": "Not found"})
def _json_response(self, status: int, data: Dict):
self.send_response(status)
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(data).encode())
Handler.server_instance = DFlashServer(target_model, draft_model, block_size)
return Handler
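# Usage sketch for create_app (model path is a placeholder; the app path requires fastapi + uvicorn):
#   app = create_app("mlx-community/Qwen3-8B-4bit", draft_model="./dflash-drafter-mlx")
#   uvicorn.run(app, host="127.0.0.1", port=8000)
# Without FastAPI installed, create_app returns a BaseHTTPRequestHandler subclass:
#   HTTPServer(("127.0.0.1", 8000), create_app("mlx-community/Qwen3-8B-4bit")).serve_forever()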
def main():
import argparse
parser = argparse.ArgumentParser(description="DFlash MLX OpenAI-compatible server")
parser.add_argument("--target", required=True, help="Target model path or HF ID")
parser.add_argument("--draft", default=None, help="Draft model path or HF ID")
parser.add_argument("--block-size", type=int, default=16)
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--device", default="metal")
args = parser.parse_args()
    try:
        import uvicorn
        from fastapi import FastAPI, Request
        from fastapi.responses import StreamingResponse
        # Build the server only after the web framework imports succeed; the fallback
        # path constructs its own instance via create_app, so loading it up front
        # would load the models twice when FastAPI is missing.
        server = DFlashServer(args.target, args.draft, args.block_size, args.device)
        app = FastAPI()
@app.get("/health")
async def health():
return server.health()
@app.get("/v1/models")
async def models():
return server.models()
@app.get("/metrics")
async def metrics():
return server.metrics()
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
body = await request.json()
result = server.chat_completions(
messages=body.get("messages", []),
max_tokens=body.get("max_tokens", 1024),
temperature=body.get("temperature", 0.0),
stream=body.get("stream", False),
stop=body.get("stop", None),
)
if body.get("stream", False):
return StreamingResponse(result, media_type="text/event-stream")
return result
print(f"[Server] Starting FastAPI on http://{args.host}:{args.port}")
uvicorn.run(app, host=args.host, port=args.port)
except ImportError:
print("[Server] FastAPI/uvicorn not available, using simple HTTP server")
from http.server import HTTPServer
handler = create_app(args.target, args.draft, args.block_size)
httpd = HTTPServer((args.host, args.port), handler)
print(f"[Server] Starting simple HTTP on http://{args.host}:{args.port}")
httpd.serve_forever()
if __name__ == "__main__":
main()