Subhadip007 commited on
Commit
5c095ca
·
1 Parent(s): 2671aea

feat: FastAPI backend complete

Browse files

- FastAPI application with lifespan startup (models pre-loaded)
- POST /query: full RAG pipeline over HTTP with Pydantic validation
- GET /health: system status with vector DB and BM25 index sizes
- CORS middleware: browser frontends can call the API
- asyncio.to_thread: CPU-bound RAG runs without blocking event loop
- Auto-generated Swagger UI at /docs (OAS 3.1)
- Warm query latency: ~3s after first request warms models

Endpoints:
GET / API info
GET /health System health check
POST /query Research paper Q&A with citations

Files changed (5)
  1. run_api.py +27 -0
  2. src/api/__init__.py +0 -0
  3. src/api/main.py +237 -0
  4. src/api/schemas.py +85 -0
  5. test_api.py +37 -0
run_api.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Start the ResearchPilot FastAPI server.
3
+
4
+ Run from project root:
5
+ python run_api.py
6
+
7
+ Then visit:
8
+ http://localhost:8000/docs <- Interactive API documentation
9
+ http://localhost:8000/health <- Health check
10
+ http://localhost:8000/ <- API info
11
+ """
12
+
13
+ import uvicorn
14
+ from config.settings import API_HOST, API_PORT, API_RELOAD
15
+
16
+ if __name__ == "__main__":
17
+ print("Starting ResearchPilot API...")
18
+ print(f"API docs: http://localhost:{API_PORT}/docs")
19
+ print(f"Health: http://localhost:{API_PORT}/health")
20
+
21
+ uvicorn.run(
22
+ "src.api.main:app",
23
+ host = API_HOST,
24
+ port = API_PORT,
25
+ reload = API_RELOAD, # Auto-restart on code changes (dev only)
26
+ workers = 1, # Single worker for dev (no GPU sharing issues)
27
+ )
src/api/__init__.py ADDED
File without changes
src/api/main.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ResearchPilot FastAPI application.
3
+
4
+ STARTUP BEHAVIOR:
5
+ When the server starts, it loads ALL models into memory:
6
+ - BGE embedding model (~110MB)
7
+ - Cross-encoder re-ranker (~80MB)
8
+ - BM25 index (~40MB)
9
+ - Qdrant connection
10
+
11
+ This takes ~15 seconds once, then every request is fast.
12
+ This is called "warm start" - the model is always ready.
13
+
14
+ Without this, the first request after server restart
15
+ would take 20+ seconds. Unacceptable for production.
16
+
17
+ LIFESPAN PATTERN:
18
+ FastAPI's lifespan context manager runs code at startup
19
+ and shutdown. We use it to initialize the RAG pipeline
20
+ once and store it in app.state for all requests to share.
21
+ """
22
+
23
+ import asyncio
24
+ import time
25
+ from contextlib import asynccontextmanager
26
+
27
+ from fastapi import FastAPI, HTTPException, Request
28
+ from fastapi.middleware.cors import CORSMiddleware
29
+ from fastapi.responses import JSONResponse
30
+
31
+ from src.api.schemas import (
32
+ QueryRequest,
33
+ QueryResponse,
34
+ CitationSchema,
35
+ HealthResponse,
36
+ ErrorResponse,
37
+ )
38
+ from src.rag.pipeline import RAGPipeline
39
+ from src.utils.logger import setup_logger, get_logger
40
+
41
+
42
# Configure project-wide logging once at import time, then obtain a
# module-scoped logger (mirrors the stdlib logging.getLogger(__name__) idiom).
setup_logger()
logger = get_logger(__name__)
44
+
45
+
46
+ # ---------------------------------------------------------
47
+ # LIFESPAN - runs at startup and shutdown
48
+ # ---------------------------------------------------------
49
+
50
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage application resources across the server's lifetime.

    Code before ``yield`` runs once at startup; code after ``yield``
    runs once at shutdown. The RAG pipeline is constructed here so a
    single warm instance is shared by every request via ``app.state``.
    """
    # ---- startup: load all models into memory (warm start) ----
    logger.info("ResearchPilot API starting up...")
    t0 = time.time()

    # One shared pipeline per process; handlers read it back from
    # request.app.state.rag_pipeline.
    app.state.rag_pipeline = RAGPipeline()

    elapsed = time.time() - t0
    logger.info(f"API ready in {elapsed:.1f}s")

    yield  # server is live and handling requests while suspended here

    # ---- shutdown ----
    logger.info("ResearchPilot API shutting down...")
75
+
76
+
77
+ # ---------------------------------------------------------
78
+ # APP INITIALIZATION
79
+ # ---------------------------------------------------------
80
+
81
app = FastAPI(
    title="ResearchPilot API",
    description="Production RAG system for ML research paper Q&A",
    version="1.0.0",
    lifespan=lifespan,
    docs_url="/docs",  # Swagger UI at http://localhost:8000/docs
    redoc_url="/redoc",  # ReDoc at http://localhost:8000/redoc
)

# CORS: lets browser frontends on other origins (e.g. localhost:3000)
# call this API — browsers block such cross-origin requests otherwise.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # in production, restrict to your domain
    allow_methods=["*"],
    allow_headers=["*"],
)
98
+
99
+ # ---------------------------------------------------------
100
+ # EXCEPTION HANDLER
101
+ # ---------------------------------------------------------
102
+
103
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """
    Convert any unhandled exception into a clean JSON 500 response.

    Without this handler FastAPI returns a bare 500 with no detail.
    The full traceback is logged server-side so the client payload
    can stay small.
    """
    # Lazy %-args + exc_info: records the complete traceback in the
    # logs, not just the exception's string form (an f-string would
    # discard the traceback entirely).
    logger.error("Unhandled exception on %s: %s", request.url, exc, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            # NOTE(review): str(exc) may leak internals to clients;
            # consider a generic message in production.
            "detail": str(exc),
            "code": 500,
        },
    )
118
+
119
+
120
+ # ---------------------------------------------------------
121
+ # ROUTES
122
+ # ---------------------------------------------------------
123
+
124
+ @app.get(
125
+ "/health",
126
+ response_model = HealthResponse,
127
+ summary = "Health check",
128
+ tags = ["System"],
129
+ )
130
+ async def health_check(request: Request) -> HealthResponse:
131
+ """
132
+ Returns system health status.
133
+ Used by deployment platforms to verify the service is running.
134
+ Also useful for debugging - shows database sizes.
135
+ """
136
+ pipeline = request.app.state.rag_pipeline
137
+
138
+ # Get Qdrant collection size
139
+ qdrant_size = pipeline.retriever.hybrid_retriever.qdrant.get_collection_size()
140
+
141
+ # Get BM25 index size
142
+ bm25_size = len(pipeline.retriever.hybrid_retriever.bm25.chunk_ids)
143
+
144
+ return HealthResponse(
145
+ status = "healthy",
146
+ model = "llama-3.3-70b-versatile",
147
+ vector_db_size = qdrant_size,
148
+ bm25_index_size = bm25_size,
149
+ version = "1.0.0",
150
+ )
151
+
152
+
153
@app.post(
    "/query",
    response_model=QueryResponse,
    summary="Query research papers",
    tags=["RAG"],
)
async def query_papers(
    request: Request,
    query_input: QueryRequest,
) -> QueryResponse:
    """
    Submit a natural language question about ML research.

    The system retrieves relevant paper excerpts and generates
    a grounded answer with citations.

    - **question**: Your research question (3-500 characters)
    - **top_k**: Number of paper chunks to retrieve (1-20, default 5)
    - **filter_category**: Filter by ArXiv category (e.g. cs.LG)
    - **filter_year_gte**: Only include papers from this year onwards

    Raises:
        HTTPException(500): if the RAG pipeline fails for any reason.
    """
    pipeline = request.app.state.rag_pipeline

    logger.info(
        f"Query received: '{query_input.question[:60]}' "
        f"[top_k={query_input.top_k}]"
    )

    # WHY asyncio.to_thread: the RAG pipeline is CPU-bound (not async).
    # Running it directly in this handler would block the entire event
    # loop — no other request could be processed while one query runs.
    # to_thread executes it in a worker thread, keeping the loop free.
    try:
        response = await asyncio.to_thread(
            pipeline.query,
            query_input.question,
            query_input.top_k,
            query_input.filter_category,
            query_input.filter_year_gte,
        )
    except Exception as e:
        # logger.exception captures the full traceback (not just str(e));
        # `from e` keeps the causal chain on the re-raised HTTPException.
        logger.exception("RAG pipeline error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Map the pipeline's RAGResponse dataclass onto the API schema.
    citations = [
        CitationSchema(
            paper_id=c.get("paper_id", ""),
            title=c.get("title", ""),
            authors=c.get("authors", []),
            published_date=c.get("published_date", ""),
            arxiv_url=c.get("arxiv_url", ""),
        )
        for c in response.citations
    ]

    return QueryResponse(
        answer=response.answer,
        citations=citations,
        query=response.query,
        chunks_used=len(response.retrieved_chunks),
        retrieval_time_ms=response.retrieval_time_ms,
        generation_time_ms=response.generation_time_ms,
        total_time_ms=response.total_time_ms,
        has_context=response.has_context,
    )
223
+
224
+
225
@app.get(
    "/",
    summary="API root",
    tags=["System"],
)
async def root():
    """API root - confirms service is running."""
    # Static metadata payload; handy for quick curl sanity checks.
    info = {
        "service": "ResearchPilot API",
        "version": "1.0.0",
        "docs": "/docs",
        "health": "/health",
    }
    return info
src/api/schemas.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic schemas for API request and response validation.
3
+
4
+ WHY PYDANTIC SCHEMAS IN THE API LAYER:
5
+ FastAPI uses these to:
6
+ 1. Validate incoming requests (wrong types -> automatic 422 error)
7
+ 2. Serialize outgoing responses (Python objects -> JSON)
8
+ 3. Generate automatic API documentation (OpenAPI/Swagger)
9
+
10
+ You get input validation AND documentation for free.
11
+ """
12
+
13
+ from pydantic import BaseModel, Field
14
+ from typing import Optional
15
+
16
+
17
+
18
class QueryRequest(BaseModel):
    """
    Schema for POST /query request body.

    Field() adds validation constraints that FastAPI turns into
    automatic 422 errors plus OpenAPI documentation.
    """
    question: str = Field(
        ...,  # ... means required
        min_length=3,
        max_length=500,
        description="Research question to answer",
        examples=["How does LoRA reduce trainable parameters?"],
    )
    top_k: int = Field(
        default=5,
        ge=1,  # ge = greater than or equal
        le=20,
        description="Number of chunks to retrieve",
    )
    filter_category: Optional[str] = Field(
        default=None,
        description="ArXiv category filter, e.g. 'cs.LG'",
        # Fixed: Pydantic v2's Field takes `examples` (a list); the
        # singular `example` kwarg is not a v2 parameter — now
        # consistent with `question` above.
        examples=["cs.LG"],
    )
    filter_year_gte: Optional[int] = Field(
        default=None,
        ge=2020,
        le=2030,
        description="Only include papers from this year onwards",
        examples=[2024],
    )
49
+
50
+
51
class CitationSchema(BaseModel):
    """A single cited paper."""
    paper_id: str        # ArXiv paper identifier
    title: str           # paper title
    authors: list[str]   # author names
    published_date: str  # publication date (string format set upstream — confirm)
    arxiv_url: str       # link to the paper on arxiv.org
58
+
59
+
60
class QueryResponse(BaseModel):
    """Schema for POST /query response."""
    answer: str                      # generated answer text
    citations: list[CitationSchema]  # papers referenced by the answer
    query: str                       # original question, echoed back
    chunks_used: int                 # count of retrieved chunks used
    retrieval_time_ms: float         # time spent in retrieval
    generation_time_ms: float        # time spent in generation
    total_time_ms: float             # end-to-end latency
    has_context: bool                # whether retrieval produced context (semantics defined by pipeline)
70
+
71
+
72
class HealthResponse(BaseModel):
    """Schema for GET /health response."""
    status: str             # e.g. "healthy"
    model: str              # LLM identifier in use
    vector_db_size: int     # vectors in the Qdrant collection
    bm25_index_size: int    # chunks in the BM25 index
    version: str = "1.0.0"  # API version
79
+
80
+
81
class ErrorResponse(BaseModel):
    """Schema for error responses."""
    error: str   # short error category, e.g. "Internal server error"
    detail: str  # human-readable explanation
    code: int    # HTTP status code
test_api.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Manual smoke test for the API.
# Run this in a SEPARATE terminal while run_api.py is running.
import json

import requests

BASE_URL = "http://localhost:8000"

# Test 1: Health check
print("Testing /health...")
r = requests.get(f"{BASE_URL}/health")
# Fail fast with the real HTTP status instead of a confusing KeyError
# later when the error body lacks the expected fields.
r.raise_for_status()
print(json.dumps(r.json(), indent=2))

# Test 2: Query
print("\nTesting /query...")
payload = {
    "question": "What is LoRA and how does it work?",
    "top_k": 5
}
r = requests.post(f"{BASE_URL}/query", json=payload)
r.raise_for_status()
data = r.json()

print(f"Answer: {data['answer'][:300]}...")
print(f"\nCitations: {len(data['citations'])}")
for c in data['citations']:
    print(f"  - {c['paper_id']}: {c['title'][:50]}...")
print(f"\nTotal time: {data['total_time_ms']:.0f}ms")

# Test 3: Filtered query (year filter presumably beyond the indexed
# data, exercising the low/no-context path — confirm against corpus)
print("\nTesting /query with filter...")
payload = {
    "question": "graph neural network applications",
    "top_k": 3,
    "filter_year_gte": 2026
}
r = requests.post(f"{BASE_URL}/query", json=payload)
r.raise_for_status()
data = r.json()
print(f"Answer: {data['answer'][:200]}...")
print(f"Citations: {len(data['citations'])}")