import os
import secrets
import time
from typing import List, Literal, Optional, Union, Dict

import numpy as np
import tiktoken
import torch
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, status, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from scipy.interpolate import interp1d
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import PolynomialFeatures
|
|
|
|
| |
| sk_key = os.environ.get('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk') |
|
|
| |
| app = FastAPI() |
|
|
| app.add_middleware( |
| CORSMiddleware, |
| allow_origins=["*"], |
| allow_credentials=True, |
| allow_methods=["*"], |
| allow_headers=["*"], |
| ) |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| if torch.cuda.is_available(): |
| print('本次加载模型的设备为GPU: ', torch.cuda.get_device_name(0)) |
| else: |
| print('本次加载模型的设备为CPU.') |
| model = SentenceTransformer('Qwen/Qwen3-Embedding-4B',device=device) |
|
|
| |
| security = HTTPBearer() |
|
|
|
|
|
|
| class ChatMessage(BaseModel): |
| role: Literal["user", "assistant", "system"] |
| content: str |
|
|
|
|
| class DeltaMessage(BaseModel): |
| role: Optional[Literal["user", "assistant", "system"]] = None |
| content: Optional[str] = None |
|
|
| class ChatCompletionRequest(BaseModel): |
| model: str |
| messages: List[ChatMessage] |
| temperature: Optional[float] = None |
| top_p: Optional[float] = None |
| max_length: Optional[int] = None |
| stream: Optional[bool] = False |
|
|
|
|
| class ChatCompletionResponseChoice(BaseModel): |
| index: int |
| message: ChatMessage |
| finish_reason: Literal["stop", "length"] |
|
|
|
|
| class ChatCompletionResponseStreamChoice(BaseModel): |
| index: int |
| delta: DeltaMessage |
| finish_reason: Optional[Literal["stop", "length"]] |
|
|
| class ChatCompletionResponse(BaseModel): |
| model: str |
| object: Literal["chat.completion", "chat.completion.chunk"] |
| choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] |
| created: Optional[int] = Field(default_factory=lambda: int(time.time())) |
|
|
| class EmbeddingRequest(BaseModel): |
| input: List[str] |
| model: str |
|
|
| class EmbeddingResponse(BaseModel): |
| data: list |
| model: str |
| object: str |
| usage: dict |
|
|
| def num_tokens_from_string(string: str) -> int: |
| """Returns the number of tokens in a text string.""" |
| encoding = tiktoken.get_encoding('cl100k_base') |
| num_tokens = len(encoding.encode(string)) |
| return num_tokens |
|
|
| |
| def interpolate_vector(vector, target_length): |
| original_indices = np.arange(len(vector)) |
| target_indices = np.linspace(0, len(vector)-1, target_length) |
| f = interp1d(original_indices, vector, kind='linear') |
| return f(target_indices) |
|
|
| def expand_features(embedding, target_length): |
| poly = PolynomialFeatures(degree=2) |
| expanded_embedding = poly.fit_transform(embedding.reshape(1, -1)) |
| expanded_embedding = expanded_embedding.flatten() |
| if len(expanded_embedding) > target_length: |
| |
| expanded_embedding = expanded_embedding[:target_length] |
| elif len(expanded_embedding) < target_length: |
| |
| expanded_embedding = np.pad(expanded_embedding, (0, target_length - len(expanded_embedding))) |
| return expanded_embedding |
|
|
| @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) |
| async def create_chat_completion(request: ChatCompletionRequest, credentials: HTTPAuthorizationCredentials = Depends(security)): |
| if credentials.credentials != sk_key: |
| raise HTTPException( |
| status_code=status.HTTP_401_UNAUTHORIZED, |
| detail="Invalid authorization code", |
| ) |
| choice_data = ChatCompletionResponseChoice( |
| index=0, |
| message=ChatMessage(role="assistant", content='你说得对,但这个是向量模型不能对话'), |
| finish_reason="stop" |
| ) |
|
|
| return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion") |
|
|
| @app.post("/v1/embeddings", response_model=EmbeddingResponse) |
| async def get_embeddings(http_request: Request, request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)): |
| client_host = http_request.client.host |
| headers = http_request.headers |
| print(f"Client IP: {client_host}") |
| print(f"Request headers: {headers}") |
| |
| if credentials.credentials != sk_key: |
| raise HTTPException( |
| status_code=status.HTTP_401_UNAUTHORIZED, |
| detail="Invalid authorization code", |
| ) |
| |
| |
| embeddings = [model.encode(text) for text in request.input] |
| |
|
|
| |
| |
| |
| embeddings = [expand_features(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings] |
|
|
| |
| |
| embeddings = [embedding / np.linalg.norm(embedding) for embedding in embeddings] |
| |
| embeddings = [embedding.tolist() for embedding in embeddings] |
| prompt_tokens = sum(len(text.split()) for text in request.input) |
| total_tokens = sum(num_tokens_from_string(text) for text in request.input) |
|
|
| |
| response = { |
| "data": [ |
| { |
| "embedding": embedding, |
| "index": index, |
| "object": "embedding" |
| } for index, embedding in enumerate(embeddings) |
| ], |
| "model": request.model, |
| "object": "list", |
| "usage": { |
| "prompt_tokens": prompt_tokens, |
| "total_tokens": total_tokens, |
| } |
| } |
|
|
| |
| return response |
|
|
| if __name__ == "__main__": |
| |
|
|
| uvicorn.run("localembedding:app", host='0.0.0.0', port=7860, workers=1) |
|
|