| from fastapi import FastAPI, HTTPException |
| import os |
| from dotenv import load_dotenv |
|
|
| |
| from utils.chat_request import ChatRequest |
| from utils.chat_response import create_chat_response, ChatResponse |
| from utils.model import check_model, initialize_pipeline, download_model, DownloadRequest |
|
|
| |
# Module-level state for the currently active model. These are populated by
# the startup hook below and swapped in-place by the /v1/download and
# /v1/chat/completions endpoints (each declares them `global`).
model_name = None  # HF repo id of the loaded model (e.g. "org/model"), or None
pipe = None        # inference pipeline returned by initialize_pipeline()
tokenizer = None   # tokenizer paired with `pipe`


# ASGI application object; title/version surface in the generated /docs UI.
app = FastAPI(title="HF-Model-Runner API", version="0.0.1")
|
|
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers — consider migrating when touching this file again.
@app.on_event("startup")
async def startup_event():
    """Load the default model pipeline once, when the application boots.

    Reads DEFAULT_MODEL_NAME from the environment (after loading `.env`),
    then initializes the global pipeline/tokenizer. Failures are logged to
    stdout but never abort startup — the server comes up without a model.
    """
    global pipe, tokenizer, model_name

    load_dotenv()

    default_model = os.getenv("DEFAULT_MODEL_NAME", "unsloth/functiongemma-270m-it")
    print(f"应用启动,正在初始化模型: {default_model}")

    # Keep the try body minimal: only the call that can actually raise.
    try:
        pipe, tokenizer, success = initialize_pipeline(default_model)
    except Exception as e:
        print(f"✗ 启动时模型初始化失败: {e}")
        return

    if success:
        model_name = default_model
        print(f"✓ 模型 {default_model} 初始化成功")
    else:
        print(f"✗ 模型 {default_model} 初始化失败")
|
|
@app.get("/")
async def read_root():
    """Root endpoint: greet the caller and point at the interactive docs."""
    greeting = "Welcome to HF-Model-Runner API! Visit /docs for API documentation."
    return {"message": greeting}
|
|
@app.post("/v1/download")
async def download_model_endpoint(request: DownloadRequest):
    """Download the requested HuggingFace model and load it as the active pipeline.

    On successful download the model is initialized and, only if initialization
    also succeeds, committed as the current global pipeline. Returns a status
    dict; raises HTTPException(500) when the download itself fails or an
    unexpected error occurs.
    """
    global pipe, tokenizer, model_name

    try:
        success, message = download_model(request.model)
        if success:
            # Stage into locals first: if initialization fails we must NOT
            # clobber the currently working global pipeline/tokenizer.
            new_pipe, new_tokenizer, init_success = initialize_pipeline(request.model)
            if init_success:
                # Commit all three globals together so they stay consistent.
                pipe, tokenizer, model_name = new_pipe, new_tokenizer, request.model
                return {
                    "status": "success",
                    "message": message,
                    "loaded": True,
                    "current_model": model_name,
                }
            return {
                "status": "success",
                "message": message,
                "loaded": False,
                "error": "模型下载成功但初始化失败",
            }
        raise HTTPException(status_code=500, detail=message)
    except HTTPException:
        # Re-raise untouched: without this clause the generic handler below
        # would catch our own HTTPException and mangle its detail into
        # "500: <message>" via str(e).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """OpenAI-compatible chat completions endpoint.

    If the requested model differs from the one currently loaded, switch to it
    first; the switch is committed only when initialization succeeds, so a
    failed switch leaves the previously working pipeline intact. Raises
    HTTPException(500) on initialization or generation failure.
    """
    global pipe, tokenizer, model_name

    if request.model != model_name:
        # Stage into locals: assigning straight to the globals would destroy
        # the working pipeline on failure while model_name still named the old
        # model, so later requests for the old model would skip re-init and
        # use the broken pipe.
        new_pipe, new_tokenizer, success = initialize_pipeline(request.model)
        if not success:
            raise HTTPException(status_code=500, detail="模型初始化失败")
        pipe, tokenizer, model_name = new_pipe, new_tokenizer, request.model

    try:
        return create_chat_response(request, pipe, tokenizer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
| |