import asyncio
import json
import time
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM

app = FastAPI(title="OpenAI-compatible API")

# Load the tokenizer and model once at startup so every request reuses the
# same weights instead of reloading them per call.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")


class Message(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: Optional[str] = "mock-gpt-model"
    messages: List[Message]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.1
    stream: Optional[bool] = False
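

# The schema above mirrors the shape of OpenAI's Chat Completions requests.
# A request body this endpoint accepts looks like (values are illustrative):
#
#   {
#       "model": "mock-gpt-model",
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "max_tokens": 64,
#       "temperature": 0.7,
#       "stream": false
#   }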


async def _resp_async_generator(text_resp: str, model: str):
    # Replay the finished reply word by word as server-sent events, mimicking
    # the chunk format of OpenAI's streaming Chat Completions responses.
    tokens = text_resp.split(" ")

    for i, token in enumerate(tokens):
        chunk = {
            "id": f"chatcmpl-{i}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{"index": 0, "delta": {"content": token + " "}}],
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        await asyncio.sleep(0.05)
    # OpenAI streams terminate with a literal [DONE] sentinel, not JSON.
    yield "data: [DONE]\n\n"
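

# The stream can be inspected from the command line with curl (assuming the
# server is running on localhost:8000; see the entry point at the bottom):
#
#   curl -N http://localhost:8000/chat/completions \
#       -H "Content-Type: application/json" \
#       -d '{"messages": [{"role": "user", "content": "Hi"}], "stream": true}'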


@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided.")

    # Flatten the chat history into a plain-text prompt. Roles other than
    # user and assistant (e.g. system) are ignored by this simple format.
    prompt = ""
    for msg in request.messages:
        if msg.role == "user":
            prompt += f"User: {msg.content}\n"
        elif msg.role == "assistant":
            prompt += f"Assistant: {msg.content}\n"
    prompt += "Assistant:"

    # model.generate() blocks, so run it in a worker thread to keep the
    # event loop responsive.
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = await asyncio.to_thread(
        model.generate,
        **inputs,
        max_new_tokens=request.max_tokens,
        temperature=request.temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) is fragile because detokenization may not reproduce the
    # prompt text exactly.
    prompt_len = inputs["input_ids"].shape[-1]
    assistant_reply = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    if request.stream:
        return StreamingResponse(
            _resp_async_generator(assistant_reply, request.model),
            media_type="text/event-stream",
        )

    return {
        "id": "1337",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": Message(role="assistant", content=assistant_reply),
                "finish_reason": "stop",
            }
        ],
    }
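

if __name__ == "__main__":
    # Convenience entry point: running this file directly starts a dev
    # server. Assumes uvicorn is installed (FastAPI's usual ASGI server);
    # the host and port are arbitrary choices.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)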
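

# A sketch of how a client can talk to this server, assuming the official
# openai Python package is installed. The api_key value is a placeholder,
# since nothing here validates it; base_url points at the server root
# because the route is /chat/completions with no /v1 prefix.
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000", api_key="unused")
#   response = client.chat.completions.create(
#       model="mock-gpt-model",
#       messages=[{"role": "user", "content": "Hello!"}],
#   )
#   print(response.choices[0].message.content)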