Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import uuid | |
| from contextlib import asynccontextmanager | |
| from typing import AsyncGenerator | |
| from dotenv import load_dotenv | |
| from fastapi import BackgroundTasks, FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from .models import SearchProgress, SearchResponse, TripResult, TripSearchRequest | |
| from .orchestrator import run_trip_search | |
# Pull variables from a local .env file into the process environment, if present.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# In-memory registry of searches, keyed by search_id (a UUID string).
# Every value is a dict with:
#   "status":  "running" while in progress, then "done" or "error"
#   "results": the result list once the search succeeds, otherwise None
#   "queue":   asyncio.Queue carrying progress items / the final result list
# Nothing in this module removes entries, so the registry grows for the
# lifetime of the process.
searches: dict[str, dict] = dict()
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan hook: report configuration state at startup.

    Logs a warning when ``GEMINI_API_KEY`` is missing (trip searches will fail
    until it is configured) and an info line when it is present, then hands
    control to the running application. Nothing is torn down on shutdown.

    Decorated with ``asynccontextmanager`` because FastAPI expects ``lifespan``
    to return an async context manager; a bare async generator only works via
    Starlette's deprecated compatibility path.
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        logger.warning(
            "GEMINI_API_KEY is not set. Trip search will fail until it is configured."
        )
    else:
        logger.info("GEMINI_API_KEY found — ready to serve requests.")
    yield
app = FastAPI(
    title="Tripplanner API",
    version="0.1.0",
    lifespan=lifespan,
)

# CORS: the local dev-server origins plus a blanket "*" (added for Vercel
# preview deployments, whose URLs are not known in advance).
# NOTE(review): FastAPI/Starlette document that "*" should not be combined
# with allow_credentials=True. Starlette then appears to echo the caller's
# Origin back, which would let any site make credentialed requests. Because
# "*" is present, the explicit localhost entries are also redundant.
# Consider allow_origin_regex for the Vercel domain instead — confirm the
# frontend's origins / cookie needs before changing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://localhost:5174",
        "http://localhost:5175",
        "http://localhost:3000",
        "*",  # Vercel preview URLs
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # --------------------------------------------------------------------------- | |
| # Background task wrapper | |
| # --------------------------------------------------------------------------- | |
async def _run_search(search_id: str, request: TripSearchRequest) -> None:
    """Execute one trip search in the background and publish its outcome.

    Progress updates from the orchestrator are forwarded onto the search's
    queue as they arrive.

    * Success: the result list is stored under ``"results"`` and also put on
      the queue (the SSE stream turns a list item into its "results" event);
      the status becomes ``"done"``.
    * Any ``Exception``: it is logged, an ``error`` progress item carrying
      the exception text is queued, and the status becomes ``"error"``.
    """
    progress_queue: asyncio.Queue = searches[search_id]["queue"]

    def on_progress(p: SearchProgress) -> None:
        # Synchronous callback, so enqueue without awaiting.
        # NOTE(review): asyncio.Queue is not thread-safe; this assumes the
        # orchestrator calls back on the event-loop thread — confirm.
        progress_queue.put_nowait(p)

    final_status = "done"
    try:
        found = await run_trip_search(request, on_progress)
        searches[search_id]["results"] = found
        # The list itself is the signal the SSE stream uses to emit "results".
        await progress_queue.put(found)
    except Exception as exc:
        logger.exception("Error in background trip search %s", search_id)
        failure = SearchProgress(step="error", message=str(exc), progress=0)
        await progress_queue.put(failure)
        final_status = "error"
    searches[search_id]["status"] = final_status
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
async def health() -> dict:
    """Liveness probe: report that the API process is up and answering."""
    payload = {"status": "ok"}
    return payload
async def upload_session(request: Request) -> dict:
    """Accept a Playwright storage-state JSON from the local session_setup.py script.

    The raw request body is stored as ``sessions.json`` inside the directory
    named by ``BROWSER_USER_DATA_DIR`` (``./browser_data`` when unset or empty).

    Authentication: when ``SESSION_UPLOAD_SECRET`` is set, the request must
    carry the same value in the ``X-Session-Secret`` header. When it is not
    set, uploads are accepted from anyone and a warning is logged.

    Returns:
        ``{"status": "ok", "bytes": <number of bytes written>}``.

    Raises:
        HTTPException: 403 for a wrong or missing secret; 400 when the body
            is not valid JSON.
    """
    import hmac

    secret = os.getenv("SESSION_UPLOAD_SECRET", "")
    if secret:
        provided = request.headers.get("X-Session-Secret", "")
        # Constant-time comparison: a plain != leaks how many leading
        # characters matched through response timing.
        if not hmac.compare_digest(provided.encode(), secret.encode()):
            raise HTTPException(status_code=403, detail="Invalid session secret")
    else:
        logger.warning(
            "SESSION_UPLOAD_SECRET is not set; accepting unauthenticated session upload"
        )

    body = await request.body()
    # Refuse non-JSON bodies so a bad upload cannot clobber a working session
    # file (JSONDecodeError and UnicodeDecodeError are both ValueErrors).
    try:
        json.loads(body)
    except ValueError:
        raise HTTPException(
            status_code=400, detail="Session body must be valid JSON"
        ) from None

    # `or` rather than a getenv default: an empty BROWSER_USER_DATA_DIR would
    # otherwise yield dirname("sessions.json") == "" and os.makedirs("") raises.
    data_dir = os.getenv("BROWSER_USER_DATA_DIR") or "./browser_data"
    os.makedirs(data_dir, exist_ok=True)
    session_path = os.path.join(data_dir, "sessions.json")

    # Write to a sibling temp file, then atomically swap it in, so readers
    # never observe a half-written sessions.json.
    tmp_path = session_path + ".tmp"
    with open(tmp_path, "wb") as f:
        f.write(body)
    os.replace(tmp_path, session_path)

    logger.info("Session state uploaded (%d bytes)", len(body))
    return {"status": "ok", "bytes": len(body)}
async def start_search(
    request: TripSearchRequest, background_tasks: BackgroundTasks
) -> dict:
    """Register a new trip search and schedule it to run in the background.

    A fresh UUID identifies the search. Its registry entry starts as
    ``"running"`` with no results and an empty progress queue, and
    ``_run_search`` is queued as a background task. The id is returned
    immediately so the client can poll or stream the search.
    """
    new_id = f"{uuid.uuid4()}"
    searches[new_id] = dict(
        status="running",
        results=None,
        queue=asyncio.Queue(),
    )
    background_tasks.add_task(_run_search, new_id, request)
    return {"search_id": new_id}
async def get_search(search_id: str) -> dict:
    """Return the current status of a search and its results, if any.

    ``results`` stays ``None`` until the search has produced its result list.

    Raises:
        HTTPException: 404 if ``search_id`` is unknown.
    """
    entry = searches.get(search_id)
    if entry is not None:
        status, results = entry["status"], entry["results"]
        return {"status": status, "results": results}
    raise HTTPException(status_code=404, detail="Search not found")
async def stream_search(search_id: str) -> StreamingResponse:
    """Stream a search's progress to the client as Server-Sent Events (SSE).

    Items taken from the search's queue are relayed as follows:

    * each ``SearchProgress`` is sent as a plain ``data:`` event; an
      ``error`` step ends the stream;
    * the final result list is sent as an ``event: results`` message and ends
      the stream (a ``done`` step whose results are already stored sends that
      message immediately);
    * after 30 s without a queue item a ``: keep-alive`` comment is sent.

    If the background task has already finished and the queue is empty, the
    stored outcome is sent right away instead of holding the connection open
    forever. This covers a browser EventSource reconnecting after the stream
    closed, or another connection having consumed the items.

    Raises:
        HTTPException: 404 if ``search_id`` is unknown.
    """
    entry = searches.get(search_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Search not found")

    def results_event(results) -> str:
        # Serialize the final result list (None is treated as empty) as a
        # named "results" SSE event. mode="json" converts non-JSON-native
        # field values (e.g. datetimes) so json.dumps cannot fail on them.
        payload = [
            r.model_dump(mode="json") if isinstance(r, TripResult) else r
            for r in (results or [])
        ]
        return f"event: results\ndata: {json.dumps(payload)}\n\n"

    def progress_event(progress: SearchProgress) -> str:
        # Serialize one progress update as an unnamed SSE data event.
        return f"data: {json.dumps(progress.model_dump(mode='json'))}\n\n"

    async def event_generator() -> AsyncGenerator[str, None]:
        queue: asyncio.Queue = entry["queue"]
        while True:
            status = entry["status"]
            if status != "running" and queue.empty():
                # The background task is finished and nothing is left to
                # relay: replay the stored outcome rather than waiting forever.
                if status == "done":
                    yield results_event(entry.get("results"))
                else:
                    # The original error text lived only in the queued item,
                    # which has already been consumed.
                    yield progress_event(
                        SearchProgress(step="error", message="Search failed", progress=0)
                    )
                break
            try:
                item = await asyncio.wait_for(queue.get(), timeout=30.0)
            except asyncio.TimeoutError:
                # Send a keep-alive comment so the connection doesn't drop
                yield ": keep-alive\n\n"
                continue
            if isinstance(item, SearchProgress):
                yield progress_event(item)
                if item.step == "error":
                    break
                if item.step == "done":
                    results = entry.get("results")
                    if results is not None:
                        yield results_event(results)
                        break
                    # Results not stored yet: keep reading — on success
                    # _run_search enqueues the result list itself.
            elif isinstance(item, list):
                # The final TripResult list pushed by the background task
                entry["results"] = item
                yield results_event(item)
                break

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )