| """ |
| SPARKNET RAG API Routes |
| Endpoints for RAG queries, search, and indexing management. |
| """ |
|
|
| from fastapi import APIRouter, HTTPException, Query, Depends |
| from fastapi.responses import StreamingResponse |
| from typing import List, Optional |
| from pathlib import Path |
| from datetime import datetime |
| import time |
| import json |
| import sys |
| import asyncio |
|
|
| |
| PROJECT_ROOT = Path(__file__).parent.parent.parent |
| sys.path.insert(0, str(PROJECT_ROOT)) |
|
|
| from api.schemas import ( |
| QueryRequest, RAGResponse, Citation, QueryPlan, QueryIntentType, |
| SearchRequest, SearchResponse, SearchResult, |
| StoreStatus, CollectionInfo |
| ) |
| from loguru import logger |
|
|
| router = APIRouter() |
|
|
| |
# Process-local response cache: cache_key -> {"response": RAGResponse, "timestamp": epoch seconds}.
# NOTE(review): lives in this worker's memory only — not shared across processes; verify
# that multi-worker deployments tolerate per-worker cache divergence.
_query_cache = {}
# Time-to-live for a cached response, in seconds (one hour).
CACHE_TTL_SECONDS = 3600
|
|
|
|
def get_cache_key(query: str, doc_ids: Optional[List[str]]) -> str:
    """Build a deterministic cache key from the query text and document filter.

    The doc_ids filter is sorted so that the same set of documents always
    yields the same key; an empty/None filter maps to the literal "all".
    """
    import hashlib
    scope = "all" if not doc_ids else ",".join(sorted(doc_ids))
    return hashlib.md5(f"{query}:{scope}".encode()).hexdigest()
|
|
|
|
def get_cached_response(cache_key: str) -> Optional[RAGResponse]:
    """Return the cached response for *cache_key*, or None if absent or expired.

    Expired entries are evicted on access so the cache does not accumulate
    stale data. Cache hits are marked with ``from_cache = True``.
    """
    entry = _query_cache.get(cache_key)
    if entry is None:
        return None
    if time.time() - entry["timestamp"] >= CACHE_TTL_SECONDS:
        # Stale: drop it and report a miss.
        del _query_cache[cache_key]
        return None
    response = entry["response"]
    response.from_cache = True
    return response
|
|
|
|
def cache_response(cache_key: str, response: RAGResponse):
    """Store *response* under *cache_key* with the current timestamp.

    A soft cap of 1000 entries is enforced by evicting the single oldest
    entry once the cache grows past the limit.
    """
    _query_cache[cache_key] = {"response": response, "timestamp": time.time()}
    if len(_query_cache) > 1000:
        stalest_key, _ = min(_query_cache.items(), key=lambda kv: kv[1]["timestamp"])
        del _query_cache[stalest_key]
|
|
|
|
# Lazily-created singleton: the AgenticRAG pipeline is expensive to construct
# (model + agent setup), so build it once per worker instead of per request.
_rag_system = None


def _get_rag_system():
    """Return the shared AgenticRAG instance, creating it on first use.

    Returns:
        The cached AgenticRAG instance, or None if initialization failed
        (the error is logged). A failed attempt is not cached, so a later
        call retries instead of permanently poisoning the process.
    """
    global _rag_system
    if _rag_system is not None:
        return _rag_system
    try:
        from src.rag.agentic.orchestrator import AgenticRAG, RAGConfig

        config = RAGConfig(
            model_name="llama3.2:latest",
            max_revision_attempts=2,
            retrieval_top_k=10,
            final_top_k=5,
            min_confidence=0.5,
        )
        _rag_system = AgenticRAG(config)
        return _rag_system
    except Exception as e:
        logger.error(f"Failed to initialize RAG system: {e}")
        return None
|
|
|
|
@router.post("/query", response_model=RAGResponse)
async def query_documents(request: QueryRequest):
    """
    Execute a RAG query across indexed documents.

    The query goes through the 5-agent pipeline:
    1. QueryPlanner - Intent classification and query decomposition
    2. Retriever - Hybrid dense+sparse search
    3. Reranker - Cross-encoder reranking with MMR
    4. Synthesizer - Answer generation with citations
    5. Critic - Hallucination detection and validation

    Responses with confidence >= request.min_confidence are cached, keyed on
    the query text plus the doc_ids filter.
    """
    start_time = time.time()

    # Compute the cache key once; it is reused for the cache write below.
    cache_key = get_cache_key(request.query, request.doc_ids) if request.use_cache else None
    if cache_key is not None:
        cached = get_cached_response(cache_key)
        if cached:
            cached.latency_ms = (time.time() - start_time) * 1000
            return cached

    try:
        rag = _get_rag_system()
        if not rag:
            raise HTTPException(status_code=503, detail="RAG system not available")

        # Optional metadata filter restricting retrieval to specific documents.
        filters = {}
        if request.doc_ids:
            filters["document_id"] = {"$in": request.doc_ids}

        logger.info(f"Executing RAG query: {request.query[:50]}...")

        # rag.query is synchronous and slow (embedding + LLM calls); run it in
        # a worker thread so the event loop keeps serving other requests.
        result = await asyncio.to_thread(
            rag.query,
            query=request.query,
            filters=filters if filters else None,
            top_k=request.top_k,
        )

        # Convert raw source dicts into Citation models (text truncated to 300 chars).
        citations = []
        for i, source in enumerate(result.get("sources", [])):
            citations.append(Citation(
                citation_id=i + 1,
                doc_id=source.get("document_id", "unknown"),
                document_name=source.get("filename", source.get("document_id", "unknown")),
                chunk_id=source.get("chunk_id", f"chunk_{i}"),
                chunk_text=source.get("text", "")[:300],
                page_num=source.get("page_num"),
                relevance_score=source.get("relevance_score", source.get("score", 0.0)),
                bbox=source.get("bbox"),
            ))

        # Surface the planner's decisions when the pipeline returned them.
        query_plan = None
        if "plan" in result:
            plan = result["plan"]
            query_plan = QueryPlan(
                intent=QueryIntentType(plan.get("intent", "factoid").lower()),
                sub_queries=plan.get("sub_queries", []),
                keywords=plan.get("keywords", []),
                strategy=plan.get("strategy", "hybrid"),
            )

        response = RAGResponse(
            query=request.query,
            answer=result.get("answer", "I could not find an answer to your question."),
            confidence=result.get("confidence", 0.0),
            citations=citations,
            source_count=len(citations),
            query_plan=query_plan,
            from_cache=False,
            validation=result.get("validation"),
            latency_ms=(time.time() - start_time) * 1000,
            revision_count=result.get("revision_count", 0),
        )

        # Only cache answers the pipeline is reasonably confident about.
        if cache_key is not None and response.confidence >= request.min_confidence:
            cache_response(cache_key, response)

        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"RAG query failed: {e}")
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
|
|
|
|
@router.post("/query/stream")
async def query_documents_stream(request: QueryRequest):
    """
    Stream RAG response for real-time updates.

    Returns Server-Sent Events (SSE): stage markers (planning / retrieving /
    sources / synthesizing), the answer in small chunks, then a final
    'complete' event carrying confidence, citations, and validation. Errors
    are reported as an SSE event rather than an HTTP error status.
    """
    async def generate():
        try:
            rag = _get_rag_system()
            if not rag:
                yield f"data: {json.dumps({'error': 'RAG system not available'})}\n\n"
                return

            yield f"data: {json.dumps({'stage': 'planning', 'message': 'Analyzing query...'})}\n\n"
            await asyncio.sleep(0.1)

            # Optional metadata filter restricting retrieval to specific documents.
            filters = {}
            if request.doc_ids:
                filters["document_id"] = {"$in": request.doc_ids}

            yield f"data: {json.dumps({'stage': 'retrieving', 'message': 'Searching documents...'})}\n\n"

            # rag.query is synchronous and slow (embedding + LLM calls); run it
            # in a worker thread so the event loop — and this SSE connection —
            # stays responsive instead of stalling every other request.
            result = await asyncio.to_thread(
                rag.query,
                query=request.query,
                filters=filters if filters else None,
                top_k=request.top_k,
            )

            yield f"data: {json.dumps({'stage': 'sources', 'count': len(result.get('sources', []))})}\n\n"

            yield f"data: {json.dumps({'stage': 'synthesizing', 'message': 'Generating answer...'})}\n\n"

            # The answer is already complete; emit it in 50-char chunks to
            # simulate token streaming on the client side.
            answer = result.get("answer", "")
            chunk_size = 50
            for i in range(0, len(answer), chunk_size):
                chunk = answer[i:i+chunk_size]
                yield f"data: {json.dumps({'stage': 'answer', 'chunk': chunk})}\n\n"
                await asyncio.sleep(0.02)

            # Lightweight citation dicts (text truncated to 200 chars for SSE).
            citations = []
            for i, source in enumerate(result.get("sources", [])):
                citations.append({
                    "citation_id": i + 1,
                    "doc_id": source.get("document_id", "unknown"),
                    "chunk_text": source.get("text", "")[:200],
                    "relevance_score": source.get("score", 0.0),
                })

            final = {
                "stage": "complete",
                "confidence": result.get("confidence", 0.0),
                "citations": citations,
                "validation": result.get("validation"),
            }
            yield f"data: {json.dumps(final)}\n\n"

        except Exception as e:
            logger.error(f"Streaming query failed: {e}")
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
|
|
|
|
@router.post("/search", response_model=SearchResponse)
async def search_documents(request: SearchRequest):
    """
    Semantic search across indexed documents.

    Returns matching chunks without answer synthesis. Best-effort: on any
    backend failure the error is logged and an empty result set is returned
    instead of a 500.
    """
    start_time = time.time()

    try:
        from src.rag.store import get_vector_store
        from src.rag.embeddings import get_embedding_model

        store = get_vector_store()
        embeddings = get_embedding_model()

        # Embedding and vector search are synchronous, CPU/IO-bound calls;
        # run them in a worker thread so the event loop is not blocked.
        query_embedding = await asyncio.to_thread(embeddings.embed_query, request.query)

        # Optional restriction to specific documents.
        where_filter = None
        if request.doc_ids:
            where_filter = {"document_id": {"$in": request.doc_ids}}

        results = await asyncio.to_thread(
            store.similarity_search_with_score,
            query_embedding=query_embedding,
            k=request.top_k,
            where=where_filter,
        )

        # Keep only hits at or above the caller's score threshold.
        search_results = [
            SearchResult(
                chunk_id=doc.metadata.get("chunk_id", "unknown"),
                doc_id=doc.metadata.get("document_id", "unknown"),
                document_name=doc.metadata.get("filename", "unknown"),
                text=doc.page_content,
                score=score,
                page_num=doc.metadata.get("page_num"),
                chunk_type=doc.metadata.get("chunk_type", "text"),
            )
            for doc, score in results
            if score >= request.min_score
        ]

        return SearchResponse(
            query=request.query,
            total_results=len(search_results),
            results=search_results,
            latency_ms=(time.time() - start_time) * 1000,
        )

    except Exception as e:
        logger.error(f"Search failed: {e}")
        # Deliberate best-effort fallback: empty result set, not an HTTP error.
        return SearchResponse(
            query=request.query,
            total_results=0,
            results=[],
            latency_ms=(time.time() - start_time) * 1000,
        )
|
|
|
|
@router.get("/store/status", response_model=StoreStatus)
async def get_store_status():
    """Report vector store health plus document and chunk counts.

    On any failure the error is logged and a zeroed 'error' status is
    returned instead of raising.
    """
    try:
        from src.rag.store import get_vector_store

        store = get_vector_store()

        # NOTE(review): reaches into the store's private Chroma collection handle.
        collection = store._collection
        chunk_count = collection.count()

        # Distinct document ids across all chunk metadata entries.
        metadatas = collection.get(include=["metadatas"]).get("metadatas", [])
        doc_ids = {m["document_id"] for m in metadatas if m and "document_id" in m}

        info = CollectionInfo(
            name=store.collection_name,
            document_count=len(doc_ids),
            chunk_count=chunk_count,
            embedding_dimension=getattr(store, "embedding_dimension", 1024),
        )

        return StoreStatus(
            status="healthy",
            collections=[info],
            total_documents=len(doc_ids),
            total_chunks=chunk_count,
        )

    except Exception as e:
        logger.error(f"Store status check failed: {e}")
        return StoreStatus(
            status="error",
            collections=[],
            total_documents=0,
            total_chunks=0,
        )
|
|
|
|
@router.delete("/store/collection/{collection_name}")
async def clear_collection(collection_name: str, confirm: bool = Query(False)):
    """Clear a vector store collection (dangerous operation).

    Requires confirm=true as an explicit guard; 404s when the named
    collection is not the one backing the store.
    """
    # Explicit opt-in guard against accidental data loss.
    if not confirm:
        raise HTTPException(
            status_code=400,
            detail="This operation will delete all data. Set confirm=true to proceed."
        )

    try:
        from src.rag.store import get_vector_store

        store = get_vector_store()

        # Only the store's own collection can be cleared through this route.
        if store.collection_name != collection_name:
            raise HTTPException(status_code=404, detail=f"Collection not found: {collection_name}")

        # Empty filter matches every record in the collection.
        store._collection.delete(where={})

        return {"status": "cleared", "collection": collection_name, "message": "Collection cleared successfully"}

    except HTTPException:
        # Let deliberate HTTP errors (400/404 above) pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Collection clear failed: {e}")
        raise HTTPException(status_code=500, detail=f"Clear failed: {str(e)}")
|
|
|
|
@router.get("/cache/stats")
async def get_cache_stats():
    """Report query-cache size, split into still-valid and expired entries."""
    now = time.time()
    valid = 0
    for entry in _query_cache.values():
        if now - entry["timestamp"] < CACHE_TTL_SECONDS:
            valid += 1

    total = len(_query_cache)
    return {
        "total_entries": total,
        "valid_entries": valid,
        "expired_entries": total - valid,
        "ttl_seconds": CACHE_TTL_SECONDS,
    }
|
|
|
|
@router.delete("/cache")
async def clear_cache():
    """Empty the query cache and report how many entries were dropped."""
    removed = len(_query_cache)
    _query_cache.clear()
    return {"status": "cleared", "entries_removed": removed}
|
|