| |
| """ |
| FastAPI servis giriş noktası (app.py) |
| - Startup'ta modeli yükler (sıcak bekletir). |
| - /infer ile tahmin, /health ve /model_info ile kontrol sağlar. |
| - handler.py dosyası aynı klasörde olmalıdır. |
| """ |
|
|
| import os |
| import asyncio |
| from concurrent.futures import ThreadPoolExecutor |
| from typing import Any, Dict, Optional |
|
|
| from fastapi import FastAPI, Body, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
|
|
| import handler as pulse_handler |
|
|
| |
# Server bind address, port and thread-pool sizing — all overridable via env.
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "4"))


# Default Hugging Face model id; a pre-existing HF_MODEL_ID env var wins.
os.environ.setdefault("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")


# Shared pool that runs blocking handler calls off the asyncio event loop.
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
# EndpointHandler instance, created lazily by _ensure_initialized().
endpoint = None


app = FastAPI(title="Rapid ECG Inference API", version="1.0.0")
# CORS: comma-separated origin list from env; defaults to allow-all ("*").
app.add_middleware(
    CORSMiddleware,
    allow_origins=os.getenv("CORS_ALLOW_ORIGINS", "*").split(","),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
class InferenceRequest(BaseModel):
    """Request body for /infer.

    Accepts either a HF-style nested ``inputs`` dict or the same keys as
    top-level convenience fields; top-level values override ``inputs``
    (see ``_merge_payload``).
    """

    # HF Inference-Endpoints style payload: {"inputs": {...}}.
    inputs: Optional[Dict[str, Any]] = None

    # Convenience fields accepted at the top level.
    message: Optional[str] = None     # prompt text (required by /infer)
    image: Optional[Any] = None       # inline image data; exact format decided by handler
    image_url: Optional[str] = None   # remote image location
    img: Optional[Any] = None         # alternate image key also accepted by /infer

    # Optional generation parameters, forwarded only when not None.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_new_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None
    conv_mode: Optional[str] = None
    det_seed: Optional[int] = None
|
|
def _ensure_initialized():
    """Load the model once and build the EndpointHandler (idempotent)."""
    global endpoint
    already_ready = pulse_handler.model_initialized and endpoint is not None
    if already_ready:
        return
    if not pulse_handler.initialize_model():
        raise RuntimeError("Model initialization failed")
    model_id = os.getenv("HF_MODEL_ID", "CanerDedeoglu/Rapid_ECG")
    endpoint = pulse_handler.EndpointHandler(model_dir=model_id)
|
|
def _merge_payload(req: InferenceRequest) -> Dict[str, Any]:
    """Flatten the request into one dict: start from the HF-style 'inputs'
    mapping, then let non-None top-level fields override it."""
    overrides = {
        name: value
        for name in (
            "message", "image", "image_url", "img",
            "temperature", "top_p", "max_new_tokens",
            "repetition_penalty", "conv_mode", "det_seed",
        )
        if (value := getattr(req, name)) is not None
    }
    return {**(req.inputs or {}), **overrides}
|
|
async def _run_inference(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Run the blocking EndpointHandler call in the shared thread pool so
    the event loop stays responsive."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        executor, lambda: endpoint({"inputs": payload})
    )
|
|
| |
@app.on_event("startup")
async def on_startup():
    """Warm the model at startup so the first /infer request is fast."""
    _ensure_initialized()
|
|
| |
@app.get("/health")
async def health():
    """Return the handler's health report without triggering a model load."""
    report = pulse_handler.health_check()
    return report
|
|
@app.get("/model_info")
async def model_info():
    """Ensure the model is loaded, then return the handler's model metadata."""
    _ensure_initialized()
    info = pulse_handler.get_model_info()
    return info
|
|
@app.post("/infer")
async def infer(req: InferenceRequest = Body(...)):
    """Validate the request, run inference off the event loop, return the result.

    Raises:
        HTTPException 400: when 'message' or every image field is missing.
        HTTPException 500: when the handler reports an 'error' key.
    """
    _ensure_initialized()
    payload = _merge_payload(req)
    if not payload.get("message"):
        raise HTTPException(400, "Missing 'message'")
    has_image = any(payload.get(key) for key in ("image", "image_url", "img"))
    if not has_image:
        raise HTTPException(400, "Missing 'image' / 'image_url' / 'img'")
    result = await _run_inference(payload)
    if isinstance(result, dict) and result.get("error"):
        raise HTTPException(500, result["error"])
    return result
|
|
if __name__ == "__main__":
    # Local/dev entry point; RELOAD=1 enables uvicorn auto-reload.
    import uvicorn

    dev_reload = bool(int(os.getenv("RELOAD", "0")))
    uvicorn.run("app:app", host=HOST, port=PORT, reload=dev_reload)
|
|