| from fastapi import FastAPI, Request, HTTPException, Depends, Header |
| from pydantic import BaseModel, Field |
| from sentence_transformers import SentenceTransformer |
| from typing import Union, List |
| import numpy as np |
| import logging, os |
|
|
| |
# Module-level logger for this service. basicConfig only installs an
# INFO-level root handler when the root logger has no handlers yet.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
| |
async def check_authorization(
    authorization: Union[str, None] = Header(None, alias="Authorization"),
):
    """FastAPI dependency enforcing an ``Authorization: Bearer <token>`` header.

    The expected token is read from the ``AUTHORIZATION`` environment variable
    on every call.

    Args:
        authorization: Raw ``Authorization`` request header, or ``None`` when
            the client did not send one.

    Returns:
        The bearer token (the header value after ``"Bearer "``).

    Raises:
        HTTPException: 401 when the header is missing or not a Bearer header,
            when the server has no (or an empty) ``AUTHORIZATION`` configured,
            or when the supplied token does not match it.
    """
    import secrets  # local import: only needed for the constant-time comparison

    # RFC 7235 requires a 401 response to carry a challenge naming the scheme.
    challenge = {"WWW-Authenticate": "Bearer"}

    if authorization is None:
        raise HTTPException(
            status_code=401, detail="Missing Authorization header", headers=challenge
        )
    if not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format", headers=challenge
        )

    token = authorization[len("Bearer "):]
    expected = os.environ.get("AUTHORIZATION")

    # Fail closed when no secret is configured. Without the `not expected`
    # guard, an empty AUTHORIZATION value would accept the empty token sent
    # as "Bearer ". compare_digest keeps the comparison time independent of
    # how many leading characters match. It is given bytes because it rejects
    # non-ASCII str arguments.
    if not expected or not secrets.compare_digest(
        token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized access", headers=challenge)
    return token
|
|
# ASGI application; the embeddings route below is registered on it.
app = FastAPI()

# Load the embedding model once at import time so every request reuses it.
try:
    model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
except Exception as exc:
    # There is no request yet, so an HTTP error has nobody to go to. Log the
    # full traceback and abort startup with an ordinary, chained exception.
    logger.exception("Failed to load model: %s", exc)
    raise RuntimeError("Model loading failed") from exc
|
|
class EmbeddingRequest(BaseModel):
    """Request body for ``POST /v1/embeddings``.

    ``input`` is either a single text or a list of texts to embed.
    """

    input: Union[str, List[str]] = Field(...)
|
|
@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
    """Embed one or more texts and return an OpenAI-style embedding list.

    Args:
        request: Parsed body; ``request.input`` is one string or a list of strings.
        authorization: Token returned by ``check_authorization``. It is unused
            here; its presence forces authentication before the handler runs.

    Returns:
        A dict with ``object="list"``, ``data`` (one ``{"object", "embedding",
        "index"}`` entry per input, in input order), ``model`` and ``usage``.
        An empty input list yields an empty ``data`` list and zero usage.
    """
    input_data = request.input
    # Normalize so a single string and a batch share one code path.
    inputs = [input_data] if isinstance(input_data, str) else input_data

    # NOTE: this counts characters, not model tokens, so "usage" is only an
    # approximation of what an OpenAI-compatible client expects.
    char_count = sum(len(text) for text in inputs)

    if inputs:
        # normalize_embeddings=True asks the model for unit-length vectors.
        # NOTE(review): encode() is synchronous and runs on the event loop, so
        # it blocks other requests while it works. Moving it to a thread pool
        # needs care: the shared Hugging Face tokenizer is presumably not safe
        # for concurrent use. Confirm before changing.
        vectors = model.encode(inputs, normalize_embeddings=True)
    else:
        # Nothing to encode; fall through to a well-formed empty response.
        vectors = []

    data_entries = [
        {"object": "embedding", "embedding": vector.tolist(), "index": idx}
        for idx, vector in enumerate(vectors)
    ]

    return {
        "object": "list",
        "data": data_entries,
        "model": "BAAI/bge-large-zh-v1.5",
        "usage": {
            "prompt_tokens": char_count,
            "total_tokens": char_count,
        },
    }