import asyncio
import hmac
import json
import os
import time
from typing import Dict, List, Optional

import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
|
|
| |
app = FastAPI()

# CORS: any origin may call the API, but without credentials. The HTML page is
# served by this same app and /v1/mock authenticates with a secret carried in
# the JSON body (not cookies), so cross-origin credentials are not needed.
# Wildcard origins together with allow_credentials=True is not a valid CORS
# combination; Starlette copes by reflecting the caller's Origin, which would
# let any site issue credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static assets are read from ./static, relative to the process's working
# directory at startup.
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
class MockRequest(BaseModel):
    """Expected JSON body of a POST /v1/mock request."""

    # Prompt text forwarded to the WebSocket client.
    parameter: str
    # Optional model name forwarded alongside the prompt. Declared Optional so
    # that an explicit JSON null is accepted as well as omitting the field
    # (a bare `str` annotation rejects null under pydantic v2).
    model: Optional[str] = None
    # Shared secret that the endpoint checks against the API_SECRET env var.
    secret: str
|
|
class ConnectionManager:
    """Track active WebSocket connections and the reply awaited from each.

    At most one reply per client can be pending: ``response_futures`` is keyed
    by client id, so a new :meth:`broadcast` to the same client replaces the
    previously registered future.
    """

    def __init__(self):
        # Accepted connections, in the order they connected.
        self.active_connections: List[WebSocket] = []
        # Pending replies keyed by str(id(websocket)); the /ws receive loop
        # resolves them with the next text message from that client.
        self.response_futures: Dict[str, asyncio.Future] = {}

    async def connect(self, websocket: WebSocket):
        """Accept *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)
        print(f"Nouvelle connexion WebSocket. Total: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket* and fail any reply still awaited from it.

        Safe to call for a connection that is not (or no longer) tracked.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        pending = self.response_futures.pop(str(id(websocket)), None)
        if pending is not None and not pending.done():
            # Wake the waiting HTTP request immediately instead of letting it
            # sit until its timeout expires.
            pending.set_exception(ConnectionError("Le client WebSocket s'est déconnecté."))
        print(f"Déconnexion WebSocket. Total: {len(self.active_connections)}")

    async def broadcast(self, message: str):
        """Send *message* to the first connected client; return a future for its reply.

        Despite the name, only ``active_connections[0]`` receives the message.

        Returns:
            An ``asyncio.Future`` resolved with the client's next text message,
            or ``None`` when no client is connected.
        """
        if not self.active_connections:
            return None

        websocket = self.active_connections[0]
        client_id = str(id(websocket))

        # Register the future *before* sending: send_text is awaited and may
        # yield to the event loop, so a fast client could otherwise reply
        # before the future exists and the reply would be dropped.
        future = asyncio.get_running_loop().create_future()
        self.response_futures[client_id] = future
        try:
            await websocket.send_text(message)
        except BaseException:
            # The message never went out; don't leave an unresolvable future.
            if self.response_futures.get(client_id) is future:
                del self.response_futures[client_id]
            raise
        return future


# Process-wide registry shared by the HTTP endpoint and the /ws endpoint.
manager = ConnectionManager()
|
|
def verify_secret(provided_secret: str) -> bool:
    """Check *provided_secret* against the ``API_SECRET`` environment variable.

    The environment variable is read on every call.

    Returns:
        True only when ``API_SECRET`` is set to a non-empty value and equals
        *provided_secret*. False otherwise; when ``API_SECRET`` is unset or
        empty a warning is printed and every secret is rejected.
    """
    expected_secret = os.getenv("API_SECRET")

    if not expected_secret:
        print("ATTENTION: Variable d'environnement API_SECRET non définie!")
        return False

    if not isinstance(provided_secret, str):
        return False

    # Constant-time comparison so response timing does not reveal how many
    # leading characters matched. Compare bytes: compare_digest raises
    # TypeError for str arguments containing non-ASCII characters.
    return hmac.compare_digest(
        provided_secret.encode("utf-8"),
        expected_secret.encode("utf-8"),
    )
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Return static/index.html as the landing page, or 404 when it is missing."""
    index_path = "static/index.html"
    try:
        with open(index_path, "r", encoding="utf-8") as index_file:
            page = index_file.read()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="index.html not found")
    return HTMLResponse(content=page)
|
|
@app.post("/v1/mock")
async def mock_endpoint(payload: MockRequest):
    """
    Authenticate the caller, forward the prompt to a WebSocket client and
    return that client's reply.

    The prompt and model are sent exactly once, as the JSON object
    ``{"prompt": ..., "model": ...}``, via ``manager.broadcast``. The next
    text message that client sends back (within 60 seconds) is the reply.

    Returns:
        ``{"response_from_client": <reply text>,
        "completion_time_in_seconds": <elapsed seconds, rounded to 2 places>}``

    Raises:
        HTTPException: 401 for an invalid secret, 400 for a missing
            parameter, 503 when no WebSocket client is connected, 408 when the
            client does not reply in time, 500 when sending fails or on any
            other unexpected error.
    """
    start_time = time.monotonic()

    try:
        input_string = payload.parameter
        selected_model = payload.model
        provided_secret = payload.secret

        if not verify_secret(provided_secret):
            # Never log any part of a rejected secret: a near-miss (e.g. a
            # typo of the real secret) would leak most of it into the logs.
            print("Tentative d'accès avec un secret invalide.")
            raise HTTPException(
                status_code=401,
                detail="Secret invalide. Accès non autorisé."
            )

        print(f"Secret vérifié avec succès. Endpoint /v1/mock appelé avec: '{input_string}'")

        if input_string is None:
            raise HTTPException(status_code=400, detail="Le paramètre 'parameter' est manquant.")

        if not manager.active_connections:
            raise HTTPException(status_code=503, detail="Aucun client WebSocket n'est connecté.")

        message_data = {
            "prompt": input_string,
            "model": selected_model,
        }

        # Send once only: every broadcast registers a fresh reply future for
        # the client, so a second send would replace this one and make the
        # client receive the prompt twice.
        print("Envoi du message au client WebSocket...")
        response_future = await manager.broadcast(json.dumps(message_data))

        if response_future is None:
            raise HTTPException(status_code=500, detail="Échec de la diffusion du message.")

        try:
            websocket_response = await asyncio.wait_for(response_future, timeout=60.0)
        except asyncio.TimeoutError:
            print("Timeout: Aucune réponse du client WebSocket.")
            raise HTTPException(status_code=408, detail="Timeout: Le client n'a pas répondu à temps.")
        finally:
            # wait_for cancels the future on timeout; unregister it so a late
            # reply is ignored by /ws instead of being set on a cancelled
            # future (which raises InvalidStateError there).
            for client_id, pending in list(manager.response_futures.items()):
                if pending is response_future:
                    del manager.response_futures[client_id]

        print(f"Réponse reçue du WebSocket: '{websocket_response}'")
        duration = time.monotonic() - start_time
        print(f"Requête complétée en {duration:.2f} secondes.")
        return {
            "response_from_client": websocket_response,
            "completion_time_in_seconds": round(duration, 2)
        }

    except HTTPException:
        # Already a well-formed HTTP error; pass it through unchanged.
        raise
    except Exception as e:
        print(f"Erreur dans /v1/mock: {e}")
        raise HTTPException(status_code=500, detail=f"Une erreur interne est survenue: {str(e)}")
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Register a WebSocket client and route its text messages to waiting requests.

    Each text message received resolves the reply future that
    ``manager.broadcast`` registered for this client, if one is pending.
    Messages arriving with no pending future are only logged.
    """
    await manager.connect(websocket)
    client_id = str(id(websocket))
    try:
        while True:
            data = await websocket.receive_text()
            print(f"Message reçu du client: '{data}'")

            future = manager.response_futures.pop(client_id, None)
            # The HTTP side may already have given up: wait_for cancels the
            # future on timeout, and set_result on a done future raises
            # InvalidStateError, which would otherwise end this loop.
            if future is not None and not future.done():
                future.set_result(data)

    except WebSocketDisconnect:
        manager.disconnect(websocket)
        print("Client déconnecté.")
    except Exception as e:
        print(f"Erreur dans le WebSocket: {e}")
        manager.disconnect(websocket)
|
|
if __name__ == "__main__":
    # Listen on every network interface, port 7860.
    server_host, server_port = "0.0.0.0", 7860
    uvicorn.run(app, host=server_host, port=server_port)