import asyncio
import json
import logging
import os
import secrets
import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from .models import SearchProgress, SearchResponse, TripResult, TripSearchRequest
from .orchestrator import run_trip_search

load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# In-memory store: search_id -> {
#     "status": "running" | "done" | "error",
#     "results": list | None,     # set once the search succeeds
#     "error": str | None,        # exception message once the search fails
#     "queue": asyncio.Queue,     # SearchProgress items, then one terminal item
# }
# NOTE(review): entries are never evicted, so this dict grows for the lifetime
# of the process; consider a TTL sweep if the service is long-running.
searches: dict[str, dict] = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log at startup whether GEMINI_API_KEY is configured; no teardown work."""
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        logger.warning(
            "GEMINI_API_KEY is not set. Trip search will fail until it is configured."
        )
    else:
        logger.info("GEMINI_API_KEY found — ready to serve requests.")
    yield


app = FastAPI(title="Tripplanner API", version="0.1.0", lifespan=lifespan)

# NOTE(review): "*" combined with allow_credentials=True means every origin is
# effectively trusted (the CORS spec forbids a literal "*" with credentials, so
# Starlette echoes the request Origin instead) — confirm this is intended. If
# "*" exists only for Vercel preview deployments, consider
# allow_origin_regex=r"https://.*\.vercel\.app" and dropping "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://localhost:5174",
        "http://localhost:5175",
        "http://localhost:3000",
        "*",  # Vercel preview URLs
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Serialization helpers
# ---------------------------------------------------------------------------


def _serialize_results(results: list) -> list:
    """Return *results* as JSON-safe values, dumping TripResult models to dicts.

    ``mode="json"`` converts dates, datetimes, enums, etc. to JSON-native
    values so ``json.dumps`` cannot fail on them.
    """
    return [
        r.model_dump(mode="json") if isinstance(r, TripResult) else r
        for r in results
    ]


def _sse(data, event: str = "") -> str:
    """Format one Server-Sent Events frame whose ``data`` line is *data* as JSON.

    When *event* is non-empty an ``event:`` line is emitted first, otherwise the
    frame is a default (unnamed) message event.
    """
    prefix = f"event: {event}\n" if event else ""
    return f"{prefix}data: {json.dumps(data)}\n\n"


# ---------------------------------------------------------------------------
# Background task wrapper
# ---------------------------------------------------------------------------


async def _run_search(search_id: str, request: TripSearchRequest) -> None:
    """Run the trip search and feed progress into the per-search queue.

    On success the results list is stored in the entry and pushed onto the
    queue as the terminal item. On failure an ``error`` SearchProgress is
    pushed instead and the exception message is kept in ``entry["error"]``.
    The terminal state is recorded before the terminal item is queued, so a
    subscriber that sees the item also sees the final status.
    """
    entry = searches[search_id]
    queue: asyncio.Queue = entry["queue"]

    def progress_callback(p: SearchProgress) -> None:
        queue.put_nowait(p)

    try:
        results = await run_trip_search(request, progress_callback)
    except Exception as exc:
        logger.exception("Error in background trip search %s", search_id)
        entry["error"] = str(exc)
        entry["status"] = "error"
        queue.put_nowait(SearchProgress(step="error", message=str(exc), progress=0))
        return

    entry["results"] = results
    entry["status"] = "done"
    # Push the results list directly so the SSE handler emits the "results" event.
    queue.put_nowait(results)


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@app.get("/api/health")
async def health() -> dict:
    """Liveness probe; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.post("/api/session/upload")
async def upload_session(request: Request) -> dict:
    """Accept a Playwright storage-state JSON from the local session_setup.py script.

    The raw body is written to ``$BROWSER_USER_DATA_DIR/sessions.json``
    (default ``./browser_data/sessions.json``), replacing any previous file.

    Returns:
        ``{"status": "ok", "bytes": <body length>}``.

    Raises:
        HTTPException: 403 if ``SESSION_UPLOAD_SECRET`` is set and the
            ``X-Session-Secret`` header does not match it; 400 if the body is
            not valid JSON.
    """
    secret = os.getenv("SESSION_UPLOAD_SECRET", "")
    # NOTE(review): with no SESSION_UPLOAD_SECRET configured this endpoint is
    # unauthenticated — anyone who can reach it can replace the browser session.
    if secret:
        provided = request.headers.get("X-Session-Secret") or ""
        # Constant-time comparison so the secret cannot be probed via timing.
        if not secrets.compare_digest(provided.encode(), secret.encode()):
            raise HTTPException(status_code=403, detail="Invalid session secret")

    body = await request.body()
    # Refuse to overwrite the stored session with something that isn't JSON.
    try:
        json.loads(body)
    except ValueError:  # includes JSONDecodeError and UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Body must be valid JSON") from None

    session_path = os.path.join(
        os.getenv("BROWSER_USER_DATA_DIR", "./browser_data"), "sessions.json"
    )
    session_dir = os.path.dirname(session_path)
    if session_dir:  # os.makedirs("") raises when the data dir env var is empty
        os.makedirs(session_dir, exist_ok=True)
    with open(session_path, "wb") as f:
        f.write(body)
    logger.info("Session state uploaded (%d bytes)", len(body))
    return {"status": "ok", "bytes": len(body)}


@app.post("/api/search", response_model=dict)
async def start_search(
    request: TripSearchRequest, background_tasks: BackgroundTasks
) -> dict:
    """Register a new search, schedule it in the background and return its id."""
    search_id = str(uuid.uuid4())
    searches[search_id] = {
        "status": "running",
        "results": None,
        "error": None,
        "queue": asyncio.Queue(),
    }
    background_tasks.add_task(_run_search, search_id, request)
    return {"search_id": search_id}


@app.get("/api/search/{search_id}")
async def get_search(search_id: str) -> dict:
    """Return the current status and (possibly ``None``) results of a search.

    Raises:
        HTTPException: 404 if *search_id* is unknown.
    """
    entry = searches.get(search_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Search not found")
    return {"status": entry["status"], "results": entry["results"]}


@app.get("/api/search/{search_id}/stream")
async def stream_search(search_id: str) -> StreamingResponse:
    """Stream a search's progress as Server-Sent Events.

    Each SearchProgress is sent as a default ``data:`` frame. The final results
    are sent as an ``event: results`` frame, after which the stream closes; it
    also closes after an ``error`` progress frame. A ``: keep-alive`` comment is
    sent after every 30 s without updates. If the search already finished and
    its terminal item was consumed by an earlier connection, the final state is
    replayed immediately instead.

    Raises:
        HTTPException: 404 if *search_id* is unknown.
    """
    entry = searches.get(search_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Search not found")

    async def event_generator() -> AsyncGenerator[str, None]:
        queue: asyncio.Queue = entry["queue"]
        while True:
            if queue.empty() and entry["status"] != "running":
                # Finished search whose terminal item is gone (e.g. an
                # EventSource reconnect): replay the outcome rather than
                # sending keep-alives forever.
                if entry["status"] == "error":
                    failure = SearchProgress(
                        step="error", message=entry.get("error") or "", progress=0
                    )
                    yield _sse(failure.model_dump(mode="json"))
                elif entry.get("results") is not None:
                    yield _sse(_serialize_results(entry["results"]), event="results")
                return

            try:
                item = await asyncio.wait_for(queue.get(), timeout=30.0)
            except asyncio.TimeoutError:
                # Send a keep-alive comment so the connection doesn't drop
                yield ": keep-alive\n\n"
                continue

            if isinstance(item, SearchProgress):
                yield _sse(item.model_dump(mode="json"))
                if item.step == "error":
                    return
                if item.step == "done":
                    results = entry.get("results")
                    if results is not None:
                        # Emit the final results as a separate "results" event
                        yield _sse(_serialize_results(results), event="results")
                        return
                    # "done" may arrive before _run_search has stored the
                    # results; keep waiting for the results list it pushes next.
            elif isinstance(item, list):
                # _run_search pushes the final TripResult list directly
                entry["results"] = item
                yield _sse(_serialize_results(item), event="results")
                return

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )