import asyncio
import json
import logging
import os
import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from .models import SearchProgress, SearchResponse, TripResult, TripSearchRequest
from .orchestrator import run_trip_search
# Load variables from a local .env file into the process environment
# (no-op when the file is absent).
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# In-memory store: search_id -> {"status": str, "results": list | None, "queue": asyncio.Queue}
# NOTE(review): single-process state — lost on restart and not shared across
# workers; fine for a single-worker demo deployment, confirm before scaling.
searches: dict[str, dict] = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App startup/shutdown hook: log whether the Gemini key is configured."""
    if os.getenv("GEMINI_API_KEY"):
        logger.info("GEMINI_API_KEY found — ready to serve requests.")
    else:
        logger.warning(
            "GEMINI_API_KEY is not set. Trip search will fail until it is configured."
        )
    yield
app = FastAPI(title="Tripplanner API", version="0.1.0", lifespan=lifespan)
# CORS: local dev frontends plus Vercel deployments.
# A wildcard "*" origin must not be combined with allow_credentials=True —
# the CORS spec forbids `Access-Control-Allow-Origin: *` on credentialed
# responses, and it would let any site make credentialed requests. Vercel
# preview URLs are therefore matched with an origin regex instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://localhost:5174",
        "http://localhost:5175",
        "http://localhost:3000",
    ],
    allow_origin_regex=r"https://.*\.vercel\.app",  # Vercel preview URLs
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Background task wrapper
# ---------------------------------------------------------------------------
async def _run_search(search_id: str, request: TripSearchRequest) -> None:
    """Run the trip search and feed progress into the per-search queue.

    On success, stores the results in the search entry, pushes the raw result
    list onto the queue (so the SSE handler emits a "results" event), and
    marks the entry "done". On failure, pushes an error SearchProgress item
    and marks the entry "error".
    """
    entry = searches[search_id]
    queue: asyncio.Queue = entry["queue"]

    def _on_progress(update: SearchProgress) -> None:
        # Invoked synchronously by the orchestrator; never blocks.
        queue.put_nowait(update)

    try:
        trip_results = await run_trip_search(request, _on_progress)
        entry["results"] = trip_results
        # Push the results list directly so the SSE handler emits the "results" event
        await queue.put(trip_results)
    except Exception as exc:
        logger.exception("Error in background trip search %s", search_id)
        await queue.put(SearchProgress(step="error", message=str(exc), progress=0))
        entry["status"] = "error"
        return
    entry["status"] = "done"
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/api/health")
async def health() -> dict:
    """Simple liveness probe."""
    return dict(status="ok")
@app.post("/api/session/upload")
async def upload_session(request: Request) -> dict:
    """Accept a Playwright storage-state JSON from the local session_setup.py script.

    Auth: when SESSION_UPLOAD_SECRET is set, the X-Session-Secret header must
    match it (403 otherwise); when it is unset the endpoint is open — a
    local/dev convenience, confirm this is intended for production.

    Returns an ack dict with the number of bytes persisted.
    """
    secret = os.getenv("SESSION_UPLOAD_SECRET", "")
    if secret and request.headers.get("X-Session-Secret") != secret:
        raise HTTPException(status_code=403, detail="Invalid session secret")
    body = await request.body()
    session_path = os.path.join(
        os.getenv("BROWSER_USER_DATA_DIR", "./browser_data"), "sessions.json"
    )
    # os.makedirs("") raises FileNotFoundError, so only create the parent when
    # the path actually has a directory component (BROWSER_USER_DATA_DIR may
    # be set to "" or a bare name).
    parent = os.path.dirname(session_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(session_path, "wb") as f:
        f.write(body)
    logger.info("Session state uploaded (%d bytes)", len(body))
    return {"status": "ok", "bytes": len(body)}
@app.post("/api/search", response_model=dict)
async def start_search(
    request: TripSearchRequest, background_tasks: BackgroundTasks
) -> dict:
    """Register a new search and launch the background worker for it."""
    search_id = str(uuid.uuid4())
    # Seed the in-memory entry before the task starts so polling/streaming
    # endpoints can find it immediately.
    searches[search_id] = dict(status="running", results=None, queue=asyncio.Queue())
    background_tasks.add_task(_run_search, search_id, request)
    return {"search_id": search_id}
@app.get("/api/search/{search_id}")
async def get_search(search_id: str) -> dict:
    """Poll a search: current status plus results once available (404 if unknown)."""
    if search_id not in searches:
        raise HTTPException(status_code=404, detail="Search not found")
    record = searches[search_id]
    return {"status": record["status"], "results": record["results"]}
@app.get("/api/search/{search_id}/stream")
async def stream_search(search_id: str) -> StreamingResponse:
    """Stream search progress to the client as Server-Sent Events.

    Emits unnamed ``data:`` events carrying serialized SearchProgress dicts,
    a final ``event: results`` carrying the serialized result list, and
    ``: keep-alive`` comment lines while the queue is idle. The stream ends
    after the results event or after an "error" progress item.

    Raises:
        HTTPException: 404 when the search_id is unknown.
    """
    entry = searches.get(search_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Search not found")
    async def event_generator() -> AsyncGenerator[str, None]:
        # Drain the per-search queue that _run_search feeds.
        queue: asyncio.Queue = entry["queue"]
        while True:
            try:
                item = await asyncio.wait_for(queue.get(), timeout=30.0)
            except asyncio.TimeoutError:
                # Send a keep-alive comment so the connection doesn't drop
                yield ": keep-alive\n\n"
                continue
            if isinstance(item, SearchProgress):
                payload = item.model_dump()
                yield f"data: {json.dumps(payload)}\n\n"
                if item.step == "done":
                    # Emit the final results as a separate "results" event
                    # NOTE(review): if a "done" progress item arrives before
                    # _run_search has stored entry["results"], we break below
                    # without ever emitting a results event. This relies on
                    # the result list being pushed/stored first — confirm the
                    # orchestrator's event ordering.
                    results = entry.get("results")
                    if results is not None:
                        results_payload = [
                            r.model_dump() if isinstance(r, TripResult) else r
                            for r in results
                        ]
                        yield f"event: results\ndata: {json.dumps(results_payload)}\n\n"
                    break
                if item.step == "error":
                    break
            elif isinstance(item, list):
                # The orchestrator may push the final TripResult list directly
                # (this is what _run_search does on success); cache it so the
                # polling endpoint sees the same results.
                entry["results"] = item
                results_payload = [
                    r.model_dump() if isinstance(r, TripResult) else r
                    for r in item
                ]
                yield f"event: results\ndata: {json.dumps(results_payload)}\n\n"
                break
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            # Prevent proxy buffering (X-Accel-Buffering for nginx) and
            # client-side caching of the event stream.
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )
|