# Fix: orchestrator callback mismatch (commit 08e87fb, author Bas95)
import asyncio
import json
import logging
import os
import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from .models import SearchProgress, SearchResponse, TripResult, TripSearchRequest
from .orchestrator import run_trip_search
# Load environment variables (GEMINI_API_KEY, SESSION_UPLOAD_SECRET, ...) from a local .env file.
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# In-memory store: search_id -> {"status": str, "results": list | None, "queue": asyncio.Queue}
# NOTE(review): entries are never evicted, so this grows with every search —
# confirm whether cleanup/TTL is handled elsewhere or acceptable for this deployment.
searches: dict[str, dict] = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log whether the Gemini API key is configured.

    Runs once at startup (before ``yield``); there is no shutdown work.
    """
    if os.getenv("GEMINI_API_KEY"):
        logger.info("GEMINI_API_KEY found — ready to serve requests.")
    else:
        logger.warning(
            "GEMINI_API_KEY is not set. Trip search will fail until it is configured."
        )
    yield
app = FastAPI(title="Tripplanner API", version="0.1.0", lifespan=lifespan)
# NOTE(review): "*" in allow_origins together with allow_credentials=True is
# rejected by browsers for credentialed requests (the wildcard cannot be sent
# as Access-Control-Allow-Origin with credentials) — consider
# allow_origin_regex for Vercel preview URLs instead. Verify before relying on
# cookies/auth headers cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173", "http://localhost:5174",
                   "http://localhost:5175", "http://localhost:3000",
                   "*"],  # Vercel preview URLs
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Background task wrapper
# ---------------------------------------------------------------------------
async def _run_search(search_id: str, request: TripSearchRequest) -> None:
    """Execute the trip search in the background for *search_id*.

    Progress updates from the orchestrator are forwarded into the search's
    queue; on success the final results list is stored and pushed onto the
    queue as well, on failure an error SearchProgress is pushed instead.
    """
    entry = searches[search_id]
    queue: asyncio.Queue = entry["queue"]

    def _on_progress(update: SearchProgress) -> None:
        # Synchronous callback invoked by the orchestrator; the queue is
        # unbounded so put_nowait cannot fail with QueueFull.
        queue.put_nowait(update)

    try:
        trip_results = await run_trip_search(request, _on_progress)
        entry["results"] = trip_results
        # Push the results list directly so the SSE handler emits the "results" event
        await queue.put(trip_results)
    except Exception as exc:
        logger.exception("Error in background trip search %s", search_id)
        await queue.put(SearchProgress(step="error", message=str(exc), progress=0))
        entry["status"] = "error"
        return
    entry["status"] = "done"
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/api/health")
async def health() -> dict:
    """Liveness probe — always reports the service as up."""
    return dict(status="ok")
@app.post("/api/session/upload")
async def upload_session(request: Request) -> dict:
    """Accept a Playwright storage-state JSON from the local session_setup.py script.

    The raw request body is written verbatim to
    ``<BROWSER_USER_DATA_DIR>/sessions.json``; returns a status dict with the
    number of bytes stored.
    """
    # Shared-secret check. NOTE(review): when SESSION_UPLOAD_SECRET is unset or
    # empty the check is skipped entirely and this endpoint is unauthenticated,
    # accepting arbitrary file content — confirm this is intentional and only
    # used for local development.
    secret = os.getenv("SESSION_UPLOAD_SECRET", "")
    if secret and request.headers.get("X-Session-Secret") != secret:
        raise HTTPException(status_code=403, detail="Invalid session secret")
    body = await request.body()
    session_path = os.path.join(
        os.getenv("BROWSER_USER_DATA_DIR", "./browser_data"), "sessions.json"
    )
    # Ensure the target directory exists before writing.
    os.makedirs(os.path.dirname(session_path), exist_ok=True)
    with open(session_path, "wb") as f:
        f.write(body)
    logger.info("Session state uploaded (%d bytes)", len(body))
    return {"status": "ok", "bytes": len(body)}
@app.post("/api/search", response_model=dict)
async def start_search(
    request: TripSearchRequest, background_tasks: BackgroundTasks
) -> dict:
    """Register a new trip search and launch it in the background.

    Returns the generated ``search_id`` the client uses for polling
    (GET /api/search/{id}) and SSE streaming (GET /api/search/{id}/stream).
    """
    search_id = str(uuid.uuid4())
    entry = {
        "status": "running",
        "results": None,
        "queue": asyncio.Queue(),
    }
    searches[search_id] = entry
    background_tasks.add_task(_run_search, search_id, request)
    return {"search_id": search_id}
@app.get("/api/search/{search_id}")
async def get_search(search_id: str) -> dict:
    """Poll a search's current status and (once finished) its results."""
    if (entry := searches.get(search_id)) is None:
        raise HTTPException(status_code=404, detail="Search not found")
    return {"status": entry["status"], "results": entry["results"]}
def _results_event(results: list) -> str:
    """Serialize a results list (TripResult models or plain dicts) as an SSE
    ``results`` event frame."""
    payload = [r.model_dump() if isinstance(r, TripResult) else r for r in results]
    return f"event: results\ndata: {json.dumps(payload)}\n\n"


@app.get("/api/search/{search_id}/stream")
async def stream_search(search_id: str) -> StreamingResponse:
    """Stream search progress to the client via Server-Sent Events.

    Emits one ``data:`` event per SearchProgress update, a final
    ``event: results`` payload when the search completes, and a
    ``: keep-alive`` comment after 30s of inactivity so proxies don't drop
    the connection.

    Raises:
        HTTPException: 404 when *search_id* is unknown.
    """
    entry = searches.get(search_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Search not found")

    async def event_generator() -> AsyncGenerator[str, None]:
        queue: asyncio.Queue = entry["queue"]
        while True:
            try:
                item = await asyncio.wait_for(queue.get(), timeout=30.0)
            except asyncio.TimeoutError:
                # Send a keep-alive comment so the connection doesn't drop
                yield ": keep-alive\n\n"
                continue
            if isinstance(item, SearchProgress):
                yield f"data: {json.dumps(item.model_dump())}\n\n"
                if item.step == "error":
                    break
                if item.step == "done":
                    results = entry.get("results")
                    if results is None:
                        # BUG FIX: the "done" progress update arrives through
                        # the orchestrator callback and can land on the queue
                        # *before* _run_search has stored the results and
                        # pushed the results list. Previously we broke out
                        # here, so the client never received the "results"
                        # event. Keep consuming the queue until the final
                        # results list shows up instead.
                        continue
                    yield _results_event(results)
                    break
            elif isinstance(item, list):
                # _run_search pushes the final TripResult list directly.
                entry["results"] = item
                yield _results_event(item)
                break

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )