File size: 6,168 Bytes
649703e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08e87fb
 
 
 
649703e
08e87fb
 
 
 
649703e
 
08e87fb
649703e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import asyncio
import hmac
import json
import logging
import os
import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from .models import SearchProgress, SearchResponse, TripResult, TripSearchRequest
from .orchestrator import run_trip_search

# Load variables from a local .env file (python-dotenv) into the environment so
# the os.getenv() calls in the lifespan hook and endpoints below can see them.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# In-memory store: search_id -> {"status": str, "results": list | None, "queue": asyncio.Queue}
#   status:  "running" while the background task runs, then "done" or "error".
#   results: None until the search succeeds, then the list returned by
#            run_trip_search.
#   queue:   SearchProgress updates and, on success, the result list itself;
#            consumed by the SSE stream endpoint.
# The dict lives in this process only: state is lost on restart and is not
# shared between multiple worker processes.
searches: dict[str, dict] = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    On startup, log whether ``GEMINI_API_KEY`` is available in the environment
    (a warning if it is missing or empty, an info line otherwise). Nothing is
    done on shutdown.
    """
    if os.getenv("GEMINI_API_KEY"):
        logger.info("GEMINI_API_KEY found — ready to serve requests.")
    else:
        logger.warning(
            "GEMINI_API_KEY is not set. Trip search will fail until it is configured."
        )
    yield


app = FastAPI(title="Tripplanner API", version="0.1.0", lifespan=lifespan)

# Browser origins allowed to call the API: local frontend dev-server ports plus
# "*" (added for Vercel preview URLs, whose hostnames vary per deployment).
# Set CORS_ALLOW_ORIGINS to a comma-separated list to replace these defaults,
# e.g. with the production frontend origin only.
_DEFAULT_CORS_ORIGINS = [
    "http://localhost:5173",
    "http://localhost:5174",
    "http://localhost:5175",
    "http://localhost:3000",
    "*",  # Vercel preview URLs
]
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or _DEFAULT_CORS_ORIGINS

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # FastAPI/Starlette document that credentialed CORS cannot be combined with
    # a wildcard origin; with "*" plus credentials, Starlette reflects any
    # caller's Origin on cookie-bearing requests. Only allow credentials when
    # every permitted origin is listed explicitly.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Background task wrapper
# ---------------------------------------------------------------------------

async def _run_search(search_id: str, request: TripSearchRequest) -> None:
    """Run the trip search and feed progress into the per-search queue.

    Every ``SearchProgress`` the orchestrator reports is forwarded to the
    entry's queue.

    On success:
        - the result list is stored under ``"results"``;
        - the list is also pushed onto the queue, so a connected SSE stream
          emits its ``results`` event;
        - ``"status"`` becomes ``"done"``.

    On failure or cancellation:
        - an ``error`` ``SearchProgress`` is queued;
        - the message is kept under ``"error"``;
        - ``"status"`` becomes ``"error"``.

    ``"status"`` is always written last, after the final queue item. A reader
    that sees a non-``"running"`` status and an empty queue can therefore rely
    on nothing more arriving.
    """
    entry = searches[search_id]
    queue: asyncio.Queue = entry["queue"]

    def progress_callback(p: SearchProgress) -> None:
        queue.put_nowait(p)

    try:
        results = await run_trip_search(request, progress_callback)
    except asyncio.CancelledError:
        # CancelledError is not an Exception subclass (Python 3.8+). Without this
        # branch a cancelled search would stay "running" forever and stream
        # clients would receive keep-alives indefinitely.
        message = "Search was cancelled"
        entry["error"] = message
        queue.put_nowait(SearchProgress(step="error", message=message, progress=0))
        entry["status"] = "error"
        raise
    except Exception as exc:
        logger.exception("Error in background trip search %s", search_id)
        entry["error"] = str(exc)
        await queue.put(SearchProgress(step="error", message=str(exc), progress=0))
        entry["status"] = "error"
        return

    entry["results"] = results
    # Push the results list directly so the SSE handler emits the "results" event
    await queue.put(results)
    entry["status"] = "done"


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/api/health")
async def health() -> dict:
    """Health check: unconditionally answers ``{"status": "ok"}``."""
    status = "ok"
    return {"status": status}


@app.post("/api/session/upload")
async def upload_session(request: Request) -> dict:
    """Accept a Playwright storage-state JSON from the local session_setup.py script.

    The raw request body is written to ``sessions.json`` inside
    ``BROWSER_USER_DATA_DIR`` (default ``./browser_data``), replacing any
    previous upload.

    When ``SESSION_UPLOAD_SECRET`` is set, the request must carry the same
    value in the ``X-Session-Secret`` header. When it is unset or empty,
    uploads are accepted without authentication.

    Returns:
        ``{"status": "ok", "bytes": <number of bytes written>}``.

    Raises:
        HTTPException: 403 if the secret header does not match; 400 if the
            body is not valid JSON.
    """
    secret = os.getenv("SESSION_UPLOAD_SECRET", "")
    if secret:
        provided = request.headers.get("X-Session-Secret", "")
        # Constant-time comparison: a plain != lets an attacker learn the
        # secret byte by byte from response timing.
        if not hmac.compare_digest(provided.encode(), secret.encode()):
            raise HTTPException(status_code=403, detail="Invalid session secret")

    body = await request.body()
    # Refuse non-JSON bodies so a bad upload cannot overwrite a working
    # session file with something that cannot be parsed later.
    try:
        json.loads(body)
    except ValueError:
        raise HTTPException(
            status_code=400, detail="Session state must be valid JSON"
        ) from None

    session_path = os.path.join(
        os.getenv("BROWSER_USER_DATA_DIR", "./browser_data"), "sessions.json"
    )
    session_dir = os.path.dirname(session_path)
    # An empty BROWSER_USER_DATA_DIR yields a bare filename (current directory);
    # os.makedirs("") would raise FileNotFoundError.
    if session_dir:
        os.makedirs(session_dir, exist_ok=True)
    with open(session_path, "wb") as f:
        f.write(body)
    logger.info("Session state uploaded (%d bytes)", len(body))
    return {"status": "ok", "bytes": len(body)}


# Maximum number of finished ("done"/"error") searches kept in ``searches``.
# Running searches are never evicted.
MAX_FINISHED_SEARCHES = 1000


def _prune_finished_searches(limit: int = MAX_FINISHED_SEARCHES) -> None:
    """Evict the oldest finished searches so that at most *limit* remain.

    ``searches`` is a plain dict, which preserves insertion order, so the
    first finished entries found are the earliest-created ones. Entries still
    ``"running"`` are skipped. An SSE stream already serving an evicted entry
    keeps working, because it holds its own reference to the entry dict.
    """
    finished = [sid for sid, entry in searches.items() if entry["status"] != "running"]
    for sid in finished[: max(0, len(finished) - limit)]:
        del searches[sid]


@app.post("/api/search", response_model=dict)
async def start_search(
    request: TripSearchRequest, background_tasks: BackgroundTasks
) -> dict:
    """Register a new search and schedule it to run as a background task.

    Returns:
        ``{"search_id": <uuid4 string>}``, to be used with
        ``GET /api/search/{search_id}`` or ``GET /api/search/{search_id}/stream``.
    """
    # Without eviction every search, including its results and every queued
    # progress update, would stay in memory for the lifetime of the process.
    _prune_finished_searches()
    search_id = str(uuid.uuid4())
    searches[search_id] = {
        "status": "running",
        "results": None,
        "queue": asyncio.Queue(),
    }
    background_tasks.add_task(_run_search, search_id, request)
    return {"search_id": search_id}


@app.get("/api/search/{search_id}")
async def get_search(search_id: str) -> dict:
    """Report a search's current status and its results.

    ``results`` is ``None`` until the search has completed successfully.

    Raises:
        HTTPException: 404 if *search_id* is not a known search.
    """
    if search_id not in searches:
        raise HTTPException(status_code=404, detail="Search not found")
    record = searches[search_id]
    status, results = record["status"], record["results"]
    return {"status": status, "results": results}


def _progress_event(progress: SearchProgress) -> str:
    """Format one progress update as an unnamed SSE ``data:`` event."""
    return f"data: {json.dumps(progress.model_dump(mode='json'))}\n\n"


def _results_event(results: list) -> str:
    """Format a result list as an SSE ``results`` event.

    ``TripResult`` items are dumped with ``mode="json"``. Any field types that
    ``json.dumps`` cannot handle natively are therefore converted first,
    rather than failing mid-stream. Other items are passed through unchanged.
    """
    payload = [
        r.model_dump(mode="json") if isinstance(r, TripResult) else r
        for r in results
    ]
    return f"event: results\ndata: {json.dumps(payload)}\n\n"


def _terminal_event(entry: dict) -> str:
    """Rebuild the final SSE event of a finished search from its stored entry."""
    if entry["status"] == "done" and entry.get("results") is not None:
        return _results_event(entry["results"])
    message = entry.get("error") or "Search failed"
    return _progress_event(SearchProgress(step="error", message=message, progress=0))


@app.get("/api/search/{search_id}/stream")
async def stream_search(search_id: str) -> StreamingResponse:
    """Stream a search's progress as Server-Sent Events.

    Events sent:
        - each queued ``SearchProgress``, as an unnamed ``data:`` event;
        - the final result list, as an ``event: results`` message, after
          which the stream ends;
        - for a failed search, its ``error`` progress event, after which the
          stream ends;
        - a ``: keep-alive`` comment after each 30 s without an update.

    If the search has finished and its queue is already drained (for example
    by an earlier connection), the stored outcome is sent once. The stream
    then ends instead of idling forever.

    Raises:
        HTTPException: 404 if *search_id* is not a known search.
    """
    entry = searches.get(search_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Search not found")

    async def event_generator() -> AsyncGenerator[str, None]:
        queue: asyncio.Queue = entry["queue"]
        while True:
            if entry["status"] != "running" and queue.empty():
                # The background task sets "status" only after queuing its final
                # item, so nothing more will arrive: replay the stored outcome.
                yield _terminal_event(entry)
                break

            try:
                item = await asyncio.wait_for(queue.get(), timeout=30.0)
            except asyncio.TimeoutError:
                # Send a keep-alive comment so the connection doesn't drop
                yield ": keep-alive\n\n"
                continue

            if isinstance(item, SearchProgress):
                yield _progress_event(item)

                if item.step == "error":
                    break

                if item.step == "done":
                    results = entry.get("results")
                    if results is not None:
                        yield _results_event(results)
                        break
                    # "done" can be consumed before the background task has
                    # stored the results. Keep reading: the result list itself
                    # is queued right after they are stored.

            elif isinstance(item, list):
                # The background task pushes the final TripResult list directly
                entry["results"] = item
                yield _results_event(item)
                break

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )