"""Chunk, embed, store, and retrieve document text for retrieval-augmented generation.

Embeddings and reranking are delegated to Jina's HTTP APIs, and retrieval sizes are
chosen by a Groq-hosted LLM. Chunks are stored as ``DocumentChunk`` rows and searched
by cosine distance in the database.
"""

import json
import logging
from typing import Any

import requests
from langchain_groq import ChatGroq
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session

from app.config import get_settings
from app.models import DocumentChunk

logger = logging.getLogger(__name__)

_DEFAULT_TIMEOUT_SECONDS = 60.0


def _post_json(*, url: str, api_key: str, payload: dict[str, Any], timeout: float) -> Any:
    """POST *payload* as JSON with bearer-token auth and return the decoded JSON body.

    Raises:
        requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
        requests.RequestException: on connection or timeout failures.
    """
    response = requests.post(
        url,
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        json=payload,
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()


class JinaEmbeddings:
    """Minimal client for Jina's embeddings endpoint.

    Provides the ``embed_documents`` / ``embed_query`` pair used by the vector store.
    Passages and queries are sent with different ``task`` values.
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str,
        dimensions: int,
        timeout: float = _DEFAULT_TIMEOUT_SECONDS,
    ) -> None:
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.dimensions = dimensions  # every returned vector must have exactly this length
        self.timeout = timeout

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed passages for storage; returns one vector per input, in input order."""
        return self._embed(texts=texts, task="retrieval.passage")

    def embed_query(self, text: str) -> list[float]:
        """Embed a single search query."""
        # _embed guarantees exactly one vector per input text (or raises).
        return self._embed(texts=[text], task="retrieval.query")[0]

    def _embed(self, *, texts: list[str], task: str) -> list[list[float]]:
        """Call the embeddings API and return validated vectors aligned with *texts*.

        Raises:
            ValueError: if a vector's length differs from ``self.dimensions`` or the
                number of returned vectors differs from ``len(texts)``.
            requests.RequestException: on HTTP or transport failure.
        """
        if not texts:
            return []

        body = _post_json(
            url=self.base_url,
            api_key=self.api_key,
            payload={
                "model": self.model,
                "task": task,
                "embedding_type": "float",
                "normalized": True,
                "input": texts,
            },
            timeout=self.timeout,
        )
        rows = body.get("data", [])
        # When every item carries an integer "index", order by it so vectors line up
        # with their inputs regardless of response ordering.
        if rows and all(isinstance(row.get("index"), int) for row in rows):
            rows = sorted(rows, key=lambda row: row["index"])
        vectors = [row.get("embedding", []) for row in rows]

        for vector in vectors:
            if len(vector) != self.dimensions:
                raise ValueError(
                    f"Jina embedding dimension mismatch: got {len(vector)}, expected {self.dimensions}. "
                    "Adjust EMBEDDING_DIMENSIONS or switch embedding model."
                )
        # Without this check a short response would silently misalign or drop chunks.
        if len(vectors) != len(texts):
            raise ValueError(f"Jina returned {len(vectors)} embeddings for {len(texts)} inputs.")
        return vectors


class JinaReranker:
    """Minimal client for Jina's rerank endpoint."""

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str,
        timeout: float = _DEFAULT_TIMEOUT_SECONDS,
    ) -> None:
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.timeout = timeout

    def rerank(self, *, query: str, documents: list[str], top_n: int) -> list[dict[str, Any]]:
        """Score *documents* against *query* and return the API's ``results`` list.

        Each result is read by callers for an ``index`` into *documents* and a
        ``relevance_score``; document text is not echoed back (``return_documents`` is
        False). Returns ``[]`` when *documents* is empty.

        Raises:
            requests.RequestException: on HTTP or transport failure.
        """
        if not documents:
            return []
        body = _post_json(
            url=self.base_url,
            api_key=self.api_key,
            payload={
                "model": self.model,
                "query": query,
                "top_n": top_n,
                "documents": documents,
                "return_documents": False,
            },
            timeout=self.timeout,
        )
        return body.get("results", [])


class VectorStoreService:
    """Split documents into chunks, embed and store them, and run reranked similarity search."""

    # Hard caps on how many chunks are returned / sent to the reranker per query.
    _MAX_FINAL_K = 8
    _MAX_CANDIDATE_K = 30
    # Fallback candidate_k = requested_k * this, used only if the planner reply is unusable;
    # chosen to sit inside the "2x to 4x" range the planner prompt asks for.
    _FALLBACK_CANDIDATE_MULTIPLIER = 3

    def __init__(self) -> None:
        """Build the splitter and API clients from application settings.

        Raises:
            RuntimeError: if no Jina API key is configured.
        """
        self.settings = get_settings()
        if not self.settings.jina_api_key:
            raise RuntimeError("JINA_API_KEY is required for document embedding and retrieval.")

        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=150,
            # Coarse-to-fine split points: paragraphs, lines, sentences, clauses, words, chars.
            separators=[
                "\n\n",
                "\n",
                ". ",
                "? ",
                "! ",
                "; ",
                ", ",
                " ",
                "",
            ],
            keep_separator=True,
        )
        self.embeddings = JinaEmbeddings(
            api_key=self.settings.jina_api_key,
            base_url=self.settings.jina_api_base,
            model=self.settings.jina_embedding_model,
            dimensions=self.settings.embedding_dimensions,
        )
        # Without a Groq key there is no planner; _choose_retrieval_sizes then raises.
        self.retrieval_router = (
            ChatGroq(
                api_key=self.settings.groq_api_key,
                model=self.settings.model_name,
                temperature=0,
            )
            if self.settings.groq_api_key
            else None
        )
        self.reranker = JinaReranker(
            api_key=self.settings.jina_api_key,
            base_url=self.settings.jina_reranker_api_base,
            model=self.settings.jina_reranker_model,
        )

    def _get_embeddings(self) -> Any:
        """Return the embeddings client used for both indexing and querying."""
        return self.embeddings

    def _choose_retrieval_sizes(
        self,
        *,
        db: Session,
        query: str,
        file_hashes: list[str],
        requested_k: int,
    ) -> tuple[int, int]:
        """Decide ``(final_k, candidate_k)`` for a query.

        ``final_k`` is how many chunks are returned after reranking. ``candidate_k`` is
        how many nearest chunks are fetched and sent to the reranker. The planner LLM
        proposes both values and ``_clamp_retrieval_sizes`` bounds them. If the reply
        cannot be parsed, the values are derived from *requested_k* and bounded the
        same way.

        Returns:
            ``(0, 0)`` when the selected files have no stored chunks.

        Raises:
            RuntimeError: if no planner LLM is configured (no GROQ_API_KEY).
        """
        available_chunks = db.scalar(
            select(func.count())
            .select_from(DocumentChunk)
            .where(DocumentChunk.file_hash.in_(file_hashes))
        ) or 0
        if available_chunks <= 0:
            return 0, 0
        if self.retrieval_router is None:
            raise RuntimeError("GROQ_API_KEY is required for LLM-based retrieval size selection.")

        prompt = (
            "You are a retrieval planner for a RAG system.\n"
            "Choose how many chunks to keep after reranking and how many vector candidates to send to the reranker.\n"
            "Return only valid JSON with this exact schema:\n"
            '{"final_k": 4, "candidate_k": 12}\n\n'
            "Rules:\n"
            f"- final_k must be between 1 and {min(self._MAX_FINAL_K, available_chunks)}\n"
            f"- candidate_k must be between final_k and {min(self._MAX_CANDIDATE_K, available_chunks)}\n"
            "- candidate_k should usually be around 2x to 4x final_k\n"
            "- Use larger values for broad, comparative, or synthesis-heavy queries\n"
            "- Use smaller values for narrow fact lookup queries\n\n"
            f"Query: {query}\n"
            f"Selected documents: {len(file_hashes)}\n"
            f"Available chunks: {available_chunks}\n"
            f"Requested final_k hint: {requested_k}\n"
            f"Configured minimum final_k: {self.settings.retrieval_k}\n"
            f"Configured minimum candidate_k: {self.settings.rerank_candidate_k}\n"
        )
        response = self.retrieval_router.invoke(prompt)
        content = response.content if isinstance(response.content, str) else str(response.content)

        try:
            final_k, candidate_k = self._parse_retrieval_plan(content)
        except (ValueError, KeyError, TypeError, OverflowError):
            # The plan only tunes result sizes; a malformed reply should not fail the search.
            logger.warning(
                "Retrieval planner returned an unusable plan; using defaults. Reply: %r",
                content[:200],
            )
            final_k = requested_k
            candidate_k = requested_k * self._FALLBACK_CANDIDATE_MULTIPLIER

        return self._clamp_retrieval_sizes(
            final_k=final_k,
            candidate_k=candidate_k,
            available_chunks=available_chunks,
        )

    @staticmethod
    def _parse_retrieval_plan(content: str) -> tuple[int, int]:
        """Extract ``(final_k, candidate_k)`` from the planner's reply text.

        Takes the span from the first ``{`` to the last ``}``. This handles bare JSON,
        fenced blocks in any case (```json / ```JSON / ```), and surrounding prose.

        Raises:
            ValueError, KeyError, TypeError, OverflowError: when no usable JSON object
                with integer-convertible ``final_k`` and ``candidate_k`` is present.
        """
        start = content.find("{")
        end = content.rfind("}")
        if start == -1 or end < start:
            raise ValueError("retrieval planner reply contains no JSON object")
        data = json.loads(content[start : end + 1])
        return int(data["final_k"]), int(data["candidate_k"])

    def _clamp_retrieval_sizes(
        self, *, final_k: int, candidate_k: int, available_chunks: int
    ) -> tuple[int, int]:
        """Bound proposed sizes to what the store holds and the configured limits.

        - ``final_k`` is forced into ``[1, min(available_chunks, _MAX_FINAL_K)]``.
        - ``candidate_k`` is raised to at least ``max(final_k, rerank_candidate_k)``.
        - ``candidate_k`` is then capped at ``min(available_chunks, _MAX_CANDIDATE_K)``.
          Since ``final_k`` never exceeds that cap, ``candidate_k >= final_k`` always holds.
        """
        final_k = max(1, min(final_k, available_chunks, self._MAX_FINAL_K))
        candidate_floor = max(final_k, self.settings.rerank_candidate_k)
        candidate_k = min(max(candidate_floor, candidate_k), available_chunks, self._MAX_CANDIDATE_K)
        return final_k, candidate_k

    def _rerank_matches(self, *, query: str, matches: list[dict[str, Any]], top_n: int) -> list[dict[str, Any]]:
        """Reorder *matches* by reranker relevance and keep at most *top_n*.

        Best-effort: on a rerank request failure, or when the reranker yields no usable
        indices, the first *top_n* matches in their original (vector-distance) order
        are returned. Reranked matches are copies with a ``rerank_score`` key added when
        the API supplied a numeric score.
        """
        if self.reranker is None or not matches:
            return matches[:top_n]
        try:
            results = self.reranker.rerank(
                query=query,
                documents=[match["content"] for match in matches],
                top_n=min(top_n, len(matches)),
            )
        except requests.RequestException:
            logger.warning("Jina rerank request failed; using vector-search order.", exc_info=True)
            return matches[:top_n]

        reranked: list[dict[str, Any]] = []
        for item in results:
            index = item.get("index")
            # Skip anything that does not point at one of the submitted documents.
            if not isinstance(index, int) or index < 0 or index >= len(matches):
                continue
            match = dict(matches[index])  # copy so the caller's dicts are not mutated
            score = item.get("relevance_score")
            if isinstance(score, (int, float)):
                match["rerank_score"] = float(score)
            reranked.append(match)
        return reranked or matches[:top_n]

    def add_document(self, *, db: Session, document_id: int, file_hash: str, filename: str, pages: list[tuple[int, str]]) -> None:
        """Replace the stored chunks of *document_id* with chunks of *pages*.

        Args:
            db: Session to write through. Rows are flushed, not committed.
            document_id: Owner of the chunks; its existing chunks are deleted.
            file_hash: Stored on each chunk and used to scope searches.
            filename: Stored on each chunk as metadata.
            pages: ``(page_number, page_text)`` pairs. Blank pages and blank chunks
                are skipped.

        Raises:
            ValueError: if the embeddings client returns a mismatched vector count or
                dimension.
            requests.RequestException: on embedding API failure. Existing chunks are
                left untouched in that case, because deletion happens after embedding.
        """
        chunk_rows: list[tuple[int | None, str]] = []
        for page_number, page_text in pages:
            if not page_text.strip():
                continue
            chunk_rows.extend(
                (page_number, chunk) for chunk in self.splitter.split_text(page_text) if chunk.strip()
            )

        if not chunk_rows:
            # Nothing to index. Still drop any previous chunks so a re-indexed, now-empty
            # document does not keep returning stale content in searches.
            db.execute(delete(DocumentChunk).where(DocumentChunk.document_id == document_id))
            db.flush()
            return

        embeddings = self._get_embeddings().embed_documents([chunk for _, chunk in chunk_rows])

        db.execute(delete(DocumentChunk).where(DocumentChunk.document_id == document_id))
        db.add_all(
            [
                DocumentChunk(
                    document_id=document_id,
                    file_hash=file_hash,
                    filename=filename,
                    chunk_index=index,
                    page_number=page_number,
                    content=chunk,
                    embedding=embedding,
                )
                # strict=True: a chunk without a vector (or vice versa) is an error, not a silent drop.
                for index, ((page_number, chunk), embedding) in enumerate(zip(chunk_rows, embeddings, strict=True))
            ]
        )
        db.flush()

    def similarity_search(self, *, db: Session, query: str, file_hashes: list[str], k: int = 4) -> list[dict[str, Any]]:
        """Return the chunks most relevant to *query* within the given files.

        Fetches the ``candidate_k`` nearest chunks by cosine distance, then reranks them
        down to ``final_k``. Both sizes come from ``_choose_retrieval_sizes``. *k* is
        only a hint to that planner (and the fallback when its reply is unusable).

        Returns:
            A list of dicts with ``content``, ``metadata`` (document_id, filename,
            file_hash, chunk_index, page_number) and ``distance``. A ``rerank_score``
            key is also present when reranking succeeded. Returns ``[]`` for no files
            or no stored chunks.
        """
        if not file_hashes:
            return []

        final_k, candidate_k = self._choose_retrieval_sizes(
            db=db,
            query=query,
            file_hashes=file_hashes,
            requested_k=k,
        )
        if final_k == 0:
            return []

        query_embedding = self._get_embeddings().embed_query(query)
        # Build the distance expression once; it is both selected and used for ordering.
        distance = DocumentChunk.embedding.cosine_distance(query_embedding)
        stmt = (
            select(
                DocumentChunk.document_id,
                DocumentChunk.content,
                DocumentChunk.filename,
                DocumentChunk.file_hash,
                DocumentChunk.chunk_index,
                DocumentChunk.page_number,
                distance.label("distance"),
            )
            .where(DocumentChunk.file_hash.in_(file_hashes))
            .order_by(distance)
            .limit(candidate_k)
        )

        matches: list[dict[str, Any]] = [
            {
                "content": row.content,
                "metadata": {
                    "document_id": row.document_id,
                    "filename": row.filename,
                    "file_hash": row.file_hash,
                    "chunk_index": row.chunk_index,
                    "page_number": row.page_number,
                },
                "distance": row.distance,
            }
            for row in db.execute(stmt).all()
        ]
        return self._rerank_matches(query=query, matches=matches, top_n=final_k)