| import logging |
| import os |
| from contextlib import asynccontextmanager |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
|
|
| from inference_server.session_manager import SessionManager |
|
|
| |
# Configure the root logger once at import time: INFO and above, with every
# record prefixed by its timestamp, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-scoped logger used by the lifespan hook and the endpoints in this file.
logger = logging.getLogger(__name__)
|
|
| |
# One SessionManager instance, created at import time and shared by every
# request handler and by the shutdown hook in this module.
session_manager: SessionManager = SessionManager()
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log server startup, then clean up every session on shutdown.

    Everything before the ``yield`` runs while the application starts;
    everything after it runs once the application is stopping.
    """
    logger.info("🚀 RobotHub Inference Server starting up...")

    # Control returns here only when the application is shutting down.
    yield

    logger.info("🔄 RobotHub Inference Server shutting down...")
    # Ask the shared manager to clean up all sessions before reporting completion.
    await session_manager.cleanup_all_sessions()
    logger.info("✅ RobotHub Inference Server shutdown complete")
|
|
|
|
| |
# ASGI application object; `lifespan` supplies the startup/shutdown hooks.
app = FastAPI(
    title="RobotHub Inference Server",
    description="Multi-Policy Model Inference Server for Real-time Robot Control",
    version="1.0.0",
    lifespan=lifespan,
)

# Fully permissive CORS: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# website issue credentialed cross-origin requests to this server. Confirm
# this is intended outside local development before exposing the service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
| |
class CreateSessionRequest(BaseModel):
    # Request body for POST /sessions. The create_session endpoint forwards
    # every field, unchanged, to session_manager.create_session().
    # (Documented with comments rather than a docstring so the generated
    # schema description stays as it was.)

    # Required fields.
    session_id: str
    policy_path: str
    transport_server_url: str

    # Optional fields and their defaults.
    # A single camera named "front" is assumed when none are listed.
    camera_names: list[str] = ["front"]
    # May be omitted; the create_session docstring states that a workspace is
    # generated automatically in that case.
    workspace_id: str | None = None
    policy_type: str = "act"
    language_instruction: str | None = None
|
|
|
|
class CreateSessionResponse(BaseModel):
    # Response body of POST /sessions, built by unpacking the mapping that
    # session_manager.create_session() returns.
    workspace_id: str
    # presumably maps camera name -> room id; confirm against SessionManager
    camera_room_ids: dict[str, str]
    joint_input_room_id: str
    joint_output_room_id: str
|
|
|
|
class SessionStatusResponse(BaseModel):
    # One element of the GET /sessions response, built by unpacking each entry
    # that session_manager.list_sessions() returns.

    # Always present.
    session_id: str
    status: str
    policy_path: str
    policy_type: str
    camera_names: list[str]
    workspace_id: str
    # Free-form dictionaries; their inner layout is not constrained here.
    rooms: dict
    stats: dict

    # Optional details, None when the session entry does not supply them.
    inference_stats: dict | None = None
    error_message: str | None = None
|
|
|
|
| |
@app.get("/", tags=["Health"])
async def root():
    """Health check endpoint."""
    # Fixed payload: the server being able to answer is the whole signal.
    liveness = {
        "message": "RobotHub Inference Server is running",
        "status": "healthy",
    }
    return liveness
|
|
|
|
@app.get("/health", tags=["Health"])
async def health_check():
    """Detailed health check."""
    # Fetch the session mapping once and derive both reported fields from it.
    active = session_manager.sessions
    session_count = len(active)
    session_ids = list(active.keys())
    return {
        "status": "healthy",
        "active_sessions": session_count,
        "session_ids": session_ids,
    }
|
|
|
|
| |
@app.post("/sessions", response_model=CreateSessionResponse, tags=["Sessions"])
async def create_session(request: CreateSessionRequest):
    """
    Create a new inference session.

    If workspace_id is provided, all rooms will be created in that workspace.
    If workspace_id is not provided, a new workspace will be generated automatically.
    All rooms for a session (cameras + joints) are always created in the same workspace.

    Errors:
        400 if the session manager rejects the request with a ValueError.
        500 for any other failure, including a result from the session manager
        that cannot be turned into a CreateSessionResponse.
    """
    # Only the manager call sits in this try block. pydantic's ValidationError
    # is a ValueError subclass, so building the response in here would report
    # a malformed server-side result as a client error (400).
    try:
        room_ids = await session_manager.create_session(
            session_id=request.session_id,
            policy_path=request.policy_path,
            camera_names=request.camera_names,
            transport_server_url=request.transport_server_url,
            workspace_id=request.workspace_id,
            policy_type=request.policy_type,
            language_instruction=request.language_instruction,
        )
    except ValueError as e:
        # ValueError from the manager is treated as invalid caller input.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"Failed to create session {request.session_id}")
        raise HTTPException(
            status_code=500, detail=f"Failed to create session: {e!s}"
        ) from e

    # A mapping that does not fit the response model is a server-side fault,
    # so it is logged and reported as 500 regardless of the exception type.
    try:
        return CreateSessionResponse(**room_ids)
    except Exception as e:
        logger.exception(f"Failed to create session {request.session_id}")
        raise HTTPException(
            status_code=500, detail=f"Failed to create session: {e!s}"
        ) from e
|
|
|
|
@app.get("/sessions", response_model=list[SessionStatusResponse], tags=["Sessions"])
async def list_sessions():
    """List all sessions."""
    # Each entry from the manager is unpacked as keyword arguments into the
    # response model.
    entries = await session_manager.list_sessions()
    return [SessionStatusResponse(**entry) for entry in entries]
|
|
|
|
| |
@app.post("/sessions/{session_id}/start", tags=["Control"])
async def start_inference(session_id: str):
    """Start inference for a session."""
    try:
        await session_manager.start_inference(session_id)
    except KeyError:
        # KeyError is treated as "no such session" and mapped to 404.
        not_found = f"Session {session_id} not found"
        raise HTTPException(status_code=404, detail=not_found)
    except Exception as exc:
        # Anything else is logged with its traceback and reported as 500.
        logger.exception(f"Failed to start inference for session {session_id}")
        failure = f"Failed to start inference: {exc!s}"
        raise HTTPException(status_code=500, detail=failure)

    # Both handlers raise, so this line is reached only on success.
    return {"message": f"Inference started for session {session_id}"}
|
|
|
|
@app.post("/sessions/{session_id}/stop", tags=["Control"])
async def stop_inference(session_id: str):
    """Stop inference for a session.

    Responds with 404 if the session manager reports the session id as unknown.
    """
    try:
        await session_manager.stop_inference(session_id)
    except KeyError:
        # KeyError is treated as "no such session"; the lookup error itself
        # adds nothing for the client, so it is not chained.
        raise HTTPException(
            status_code=404, detail=f"Session {session_id} not found"
        ) from None

    # The confirmation names the action performed: stopping, not starting.
    return {"message": f"Inference stopped for session {session_id}"}
|
|
|
|
@app.post("/sessions/{session_id}/restart", tags=["Control"])
async def restart_inference(session_id: str):
    """Restart inference for a session."""
    try:
        await session_manager.restart_inference(session_id)
    except KeyError:
        # KeyError is treated as "no such session" and mapped to 404.
        not_found = f"Session {session_id} not found"
        raise HTTPException(status_code=404, detail=not_found)
    except Exception as exc:
        # Anything else is logged with its traceback and reported as 500.
        logger.exception(f"Failed to restart inference for session {session_id}")
        failure = f"Failed to restart inference: {exc!s}"
        raise HTTPException(status_code=500, detail=failure)

    # Both handlers raise, so this line is reached only on success.
    return {"message": f"Inference restarted for session {session_id}"}
|
|
|
|
@app.delete("/sessions/{session_id}", tags=["Sessions"])
async def delete_session(session_id: str):
    """Delete a session."""
    try:
        await session_manager.delete_session(session_id)
    except KeyError:
        # KeyError is treated as "no such session" and mapped to 404.
        not_found = f"Session {session_id} not found"
        raise HTTPException(status_code=404, detail=not_found)

    # The handler above raises, so this line is reached only on success.
    return {"message": f"Session {session_id} deleted"}
|
|
|
|
| |
if __name__ == "__main__":
    # Imported here because uvicorn is needed only when this file is run as a
    # script.
    import uvicorn

    # The PORT environment variable overrides the default port 8001.
    port = int(os.getenv("PORT", "8001"))

    # The app is passed as an import string; the server listens on all
    # interfaces with auto-reload disabled.
    uvicorn.run(
        "inference_server.main:app",
        host="0.0.0.0",
        port=port,
        log_level="info",
        reload=False,
    )
|
|