| """ |
| Synapse-Base Inference API |
| FastAPI server for chess move prediction |
| Optimized for HF Spaces CPU environment |
| """ |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
| import time |
| import logging |
| from typing import Optional |
|
|
| from engine import SynapseEngine |
|
|
| |
# Configure root logging once at import time so every module logger
# inherits the same timestamped format.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
|
|
| |
# FastAPI application instance; title/description/version feed the
# auto-generated OpenAPI docs served at /docs.
app = FastAPI(
    title="Synapse-Base Inference API",
    description="High-performance chess engine powered by 38M parameter neural network",
    version="3.0.0"
)
|
|
| |
# CORS policy: accept browser requests from any origin so web clients can
# call the inference API directly.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive — tighten the origin list if authenticated requests
# are ever added.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
| |
| engine = None |
|
|
|
|
| |
class MoveRequest(BaseModel):
    """Request body for /get-move: a position plus optional search limits."""

    fen: str = Field(..., description="Board position in FEN notation")
    depth: Optional[int] = Field(3, ge=1, le=5, description="Search depth (1-5)")
    time_limit: Optional[int] = Field(5000, ge=1000, le=30000, description="Time limit in ms")
|
|
|
|
class MoveResponse(BaseModel):
    """Response body for /get-move with the chosen move and search statistics."""

    best_move: str          # best move found (notation determined by the engine)
    evaluation: float       # engine evaluation of the position
    depth_searched: int     # depth actually reached by the search
    nodes_evaluated: int    # number of nodes examined
    time_taken: int         # wall-clock time of the request handler, in ms
    pv: Optional[list] = None  # principal variation, if the engine provides one
|
|
|
|
class HealthResponse(BaseModel):
    """Response body for /health: liveness, model readiness, API version."""

    status: str        # "healthy" when the model is loaded, else "unhealthy"
    model_loaded: bool # True once the startup hook has populated the engine
    version: str       # API version string
|
|
|
|
| |
| @app.on_event("startup") |
| async def startup_event(): |
| """Load model on startup""" |
| global engine |
| |
| logger.info("π Starting Synapse-Base Inference API...") |
| |
| try: |
| engine = SynapseEngine( |
| model_path="/app/models/synapse_base.onnx", |
| num_threads=2 |
| ) |
| logger.info("β
Model loaded successfully") |
| logger.info(f"π Model size: {engine.get_model_size():.2f} MB") |
| |
| except Exception as e: |
| logger.error(f"β Failed to load model: {e}") |
| raise |
|
|
|
|
| |
| @app.get("/health", response_model=HealthResponse) |
| async def health_check(): |
| """Health check endpoint""" |
| return { |
| "status": "healthy" if engine is not None else "unhealthy", |
| "model_loaded": engine is not None, |
| "version": "3.0.0" |
| } |
|
|
|
|
| |
| @app.post("/get-move", response_model=MoveResponse) |
| async def get_move(request: MoveRequest): |
| """ |
| Get best move for given position |
| |
| Args: |
| request: MoveRequest with FEN, depth, and time_limit |
| |
| Returns: |
| MoveResponse with best_move and evaluation |
| """ |
| |
| if engine is None: |
| raise HTTPException(status_code=503, detail="Model not loaded") |
| |
| |
| if not engine.validate_fen(request.fen): |
| raise HTTPException(status_code=400, detail="Invalid FEN string") |
| |
| |
| start_time = time.time() |
| |
| try: |
| |
| result = engine.get_best_move( |
| fen=request.fen, |
| depth=request.depth, |
| time_limit=request.time_limit |
| ) |
| |
| |
| time_taken = int((time.time() - start_time) * 1000) |
| |
| |
| logger.info( |
| f"Move: {result['best_move']} | " |
| f"Eval: {result['evaluation']:.3f} | " |
| f"Depth: {result['depth_searched']} | " |
| f"Nodes: {result['nodes_evaluated']} | " |
| f"Time: {time_taken}ms" |
| ) |
| |
| return MoveResponse( |
| best_move=result['best_move'], |
| evaluation=result['evaluation'], |
| depth_searched=result['depth_searched'], |
| nodes_evaluated=result['nodes_evaluated'], |
| time_taken=time_taken, |
| pv=result.get('pv', None) |
| ) |
| |
| except Exception as e: |
| logger.error(f"Error processing move: {e}") |
| raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
| |
| @app.get("/") |
| async def root(): |
| """Root endpoint with API info""" |
| return { |
| "name": "Synapse-Base Inference API", |
| "version": "3.0.0", |
| "model": "38.1M parameters", |
| "architecture": "CNN-Transformer Hybrid", |
| "endpoints": { |
| "POST /get-move": "Get best move for position", |
| "GET /health": "Health check", |
| "GET /docs": "API documentation" |
| } |
| } |
|
|
|
|
| |
| if __name__ == "__main__": |
| import uvicorn |
| |
| uvicorn.run( |
| app, |
| host="0.0.0.0", |
| port=7860, |
| log_level="info", |
| access_log=True |
| ) |