| """ |
| Reranking Module for RAG-Anything |
| |
| Provides reranking functionality using: |
| 1. Gemini-based LLM reranking (free tier compatible) |
| 2. Cross-encoder style scoring |
| 3. Relevance-based reordering |
| |
| Reranking is crucial for RAG systems because: |
| - Vector search (embeddings) finds semantically similar text but may miss subtle context |
| - LLMs can deeply understand query intent and document relevance |
| - Reranking improves answer quality by promoting truly relevant chunks to the top |
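
Typical pipeline (sketch; vector_search, my_llm, and generate_answer are
placeholder names for whatever retrieval, LLM, and generation functions your
RAG stack provides):

    candidates = await vector_search(query, top_k=50)               # recall-oriented
    reranker = GeminiReranker(llm_func=my_llm)
    top_chunks = await reranker.rerank(query, candidates, top_k=5)  # precision-oriented
    answer = await generate_answer(query, top_chunks)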

Author: RAG-Anything Team
Version: 1.0.0
"""
|
|
import asyncio
import logging
import re
from typing import List, Dict, Any, Optional, Callable

logger = logging.getLogger(__name__)

class GeminiReranker:
    """
    Reranker using Gemini API for semantic relevance scoring

    This reranker takes chunks from vector search and re-scores them
    based on deep semantic understanding using an LLM.

    Why reranking matters:
    ----------------------
    Vector embeddings alone can miss:
    - Negations ("not effective" vs "effective")
    - Context dependencies ("aspirin for elderly" vs "aspirin for children")
    - Query intent ("what causes X" vs "how to prevent X")

    LLM reranking provides:
    - Contextual understanding of the query
    - Semantic relevance beyond keyword matching
    - Better handling of complex queries
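
    Example (sketch; my_llm stands for any async LLM wrapper you provide):

        reranker = GeminiReranker(llm_func=my_llm)
        top_chunks = await reranker.rerank(query, chunks, top_k=5)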
| """ |
|
|
| def __init__( |
| self, |
| llm_func: Optional[Callable] = None, |
| model_name: str = "models/gemini-2.5-flash", |
| batch_size: int = 5, |
| temperature: float = 0.1 |
| ): |
| """ |
| Initialize Gemini-based reranker |
| |
| Args: |
| llm_func: Optional LLM function to use for reranking |
| model_name: Gemini model to use (default: flash for speed) |
| batch_size: Number of chunks to process in parallel |
| temperature: Temperature for relevance scoring (low=consistent) |
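
        Expected llm_func contract (a sketch; the name my_llm is illustrative,
        but the keyword arguments match how _score_chunk calls it):

            async def my_llm(prompt: str, temperature: float = 0.1,
                             max_tokens: int = 50, **kwargs) -> str:
                ...  # return the model's text reply, e.g. "8"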
| """ |
| self.llm_func = llm_func |
| self.model_name = model_name |
| self.batch_size = batch_size |
| self.temperature = temperature |
|
|
| async def rerank( |
| self, |
| query: str, |
| chunks: List[Dict[str, Any]], |
| top_k: Optional[int] = None |
| ) -> List[Dict[str, Any]]: |
| """ |
| Rerank chunks based on relevance to query |
| |
| Process: |
| 1. Take top chunks from vector search (e.g., top 50) |
| 2. Score each chunk's relevance using LLM (0-10 scale) |
| 3. Re-order by relevance score |
| 4. Return top_k most relevant chunks |
| |
| Args: |
| query: Search query |
| chunks: List of chunks with 'content' field |
| top_k: Return only top K results (None = return all, reranked) |
| |
| Returns: |
| List of reranked chunks with 'relevance_score' field added |
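
        Example (sketch, assuming a configured `reranker` and a list of chunks):

            reranked = await reranker.rerank("aspirin side effects", chunks, top_k=3)
            best = reranked[0]["relevance_score"]  # LLM relevance on the 0-10 scale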
| """ |
| if not chunks: |
| logger.warning("No chunks to rerank") |
| return [] |
|
|
| if len(chunks) == 1: |
| logger.debug("Only one chunk, skipping reranking") |
| chunks[0]['relevance_score'] = 1.0 |
| return chunks |
|
|
| logger.info(f"Reranking {len(chunks)} chunks for query: {query[:50]}...") |
|
|
| try: |
| |
| scored_chunks = await self._score_chunks_batch(query, chunks) |
|
|
| |
| scored_chunks.sort(key=lambda x: x.get('relevance_score', 0), reverse=True) |
|
|
| |
| if top_k: |
| scored_chunks = scored_chunks[:top_k] |
|
|
| logger.info( |
| f"Reranking complete. Top score: {scored_chunks[0].get('relevance_score', 0):.2f}, " |
| f"Bottom score: {scored_chunks[-1].get('relevance_score', 0):.2f}" |
| ) |
|
|
| return scored_chunks |
|
|
| except Exception as e: |
| logger.error(f"Error during reranking: {e}", exc_info=True) |
| |
| return chunks[:top_k] if top_k else chunks |
|
|
    async def _score_chunks_batch(
        self,
        query: str,
        chunks: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Score chunks in batches for efficiency

        Args:
            query: Search query
            chunks: List of chunks to score

        Returns:
            Chunks with relevance_score added
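
        For example, with batch_size=5 and 12 chunks, the slices chunks[0:5],
        chunks[5:10], and chunks[10:12] are scored one batch at a time, with the
        chunks inside each batch scored concurrently via asyncio.gather.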
| """ |
| scored_chunks = [] |
|
|
| |
| for i in range(0, len(chunks), self.batch_size): |
| batch = chunks[i:i + self.batch_size] |
|
|
| |
| tasks = [self._score_chunk(query, chunk) for chunk in batch] |
| batch_scores = await asyncio.gather(*tasks, return_exceptions=True) |
|
|
| |
| for chunk, score_result in zip(batch, batch_scores): |
| if isinstance(score_result, Exception): |
| logger.warning(f"Failed to score chunk: {score_result}") |
| chunk['relevance_score'] = 0.0 |
| else: |
| chunk['relevance_score'] = score_result |
|
|
| scored_chunks.append(chunk) |
|
|
| return scored_chunks |
|
|
    async def _score_chunk(
        self,
        query: str,
        chunk: Dict[str, Any]
    ) -> float:
        """
        Score a single chunk's relevance to the query using LLM

        Prompt engineering approach:
        - Ask LLM to act as a relevance judge
        - Provide clear scoring criteria (0-10 scale)
        - Extract numeric score from response

        Args:
            query: Search query
            chunk: Chunk dictionary with 'content' field

        Returns:
            Relevance score (0-10)
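
        The model is expected to reply with just a number (e.g. "8" or "7.5");
        wordier replies such as "Score: 9" are handled by _extract_score.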
| """ |
| content = chunk.get('content', '') |
| if not content: |
| return 0.0 |
|
|
| |
| max_content_length = 1000 |
| if len(content) > max_content_length: |
| content = content[:max_content_length] + "..." |
|
|
| |
| prompt = f"""You are a relevance judge. Score how relevant the following passage is to answering the query. |
| |
| Query: {query} |
| |
| Passage: |
| {content} |
| |
| Scoring criteria: |
| 10 = Directly answers the query with specific, relevant information |
| 8-9 = Highly relevant, provides useful context |
| 6-7 = Somewhat relevant, contains related information |
| 4-5 = Tangentially related, limited usefulness |
| 2-3 = Barely related, mostly off-topic |
| 0-1 = Completely irrelevant |
| |
| Respond with ONLY a number from 0-10. No explanation needed.""" |
|
|
| try: |
| |
| if self.llm_func: |
| response = await self.llm_func( |
| prompt=prompt, |
| temperature=self.temperature, |
| max_tokens=50 |
| ) |
| else: |
| |
| return 5.0 |
|
|
| |
| score = self._extract_score(response) |
| return score |
|
|
| except Exception as e: |
| logger.error(f"Error scoring chunk: {e}") |
| return 5.0 |
|
|
    def _extract_score(self, response: str) -> float:
        """
        Extract numeric score from LLM response

        Handles various response formats:
        - "8.5"
        - "Score: 9"
        - "The relevance is 7/10"
        - "8"

        Args:
            response: LLM response text

        Returns:
            Extracted score (0-10), defaults to 5.0 if parsing fails
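
        Examples (doctest-style, matching the regex used below):

            >>> GeminiReranker()._extract_score("Score: 9")
            9.0
            >>> GeminiReranker()._extract_score("The relevance is 7/10")
            7.0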
| """ |
| try: |
| |
| response = response.strip() |
|
|
| |
| |
| number_pattern = r'(\d+\.?\d*)' |
| matches = re.findall(number_pattern, response) |
|
|
| if matches: |
| |
| score = float(matches[0]) |
|
|
| |
| score = max(0.0, min(10.0, score)) |
|
|
| return score |
| else: |
| logger.warning(f"Could not extract score from response: {response}") |
| return 5.0 |
|
|
| except Exception as e: |
| logger.error(f"Error extracting score: {e}") |
| return 5.0 |
|
|
|
|

# Example usage
async def main():
    """Example demonstrating reranking"""

    async def mock_llm(prompt: str, **kwargs) -> str:
        # Return a canned relevance score based on keywords from the passage
        # embedded in the prompt (the scoring criteria text is ignored)
        prompt_lower = prompt.lower()
        if "bleeding" in prompt_lower:
            return "9"
        elif "nausea" in prompt_lower or "heartburn" in prompt_lower:
            return "8"
        else:
            return "3"

    # Create the reranker with the mock LLM
    reranker = GeminiReranker(llm_func=mock_llm)

    # Query and candidate chunks (as they might come back from vector search)
    query = "What are the side effects of aspirin?"

    chunks = [
        {"content": "Aspirin can cause stomach bleeding in some patients..."},
        {"content": "The history of aspirin dates back to ancient times..."},
        {"content": "Common side effects include nausea and heartburn..."},
    ]

    # Rerank and keep the two most relevant chunks
    reranked = await reranker.rerank(query, chunks, top_k=2)

    print("Reranked results:")
    for i, chunk in enumerate(reranked, 1):
        print(f"{i}. Score: {chunk['relevance_score']:.1f} - {chunk['content'][:50]}...")


if __name__ == "__main__":
    asyncio.run(main())