"""
Reranking Module for RAG-Anything
Provides reranking functionality using:
1. Gemini-based LLM reranking (free tier compatible)
2. Cross-encoder style scoring
3. Relevance-based reordering
Reranking is crucial for RAG systems because:
- Vector search (embeddings) finds semantically similar text but may miss subtle context
- LLMs can deeply understand query intent and document relevance
- Reranking improves answer quality by promoting truly relevant chunks to the top
Author: RAG-Anything Team
Version: 1.0.0
"""
import asyncio
import logging
import re
from typing import List, Dict, Any, Optional, Callable
logger = logging.getLogger(__name__)
class GeminiReranker:
"""
Reranker using Gemini API for semantic relevance scoring
This reranker takes chunks from vector search and re-scores them
based on deep semantic understanding using an LLM.
Why reranking matters:
---------------------
Vector embeddings alone can miss:
- Negations ("not effective" vs "effective")
- Context dependencies ("aspirin for elderly" vs "aspirin for children")
- Query intent ("what causes X" vs "how to prevent X")
LLM reranking provides:
- Contextual understanding of the query
- Semantic relevance beyond keyword matching
- Better handling of complex queries
"""
def __init__(
self,
llm_func: Optional[Callable] = None,
model_name: str = "models/gemini-2.5-flash",
batch_size: int = 5,
temperature: float = 0.1
):
"""
Initialize Gemini-based reranker
Args:
llm_func: Optional LLM function to use for reranking
model_name: Gemini model to use (default: flash for speed)
batch_size: Number of chunks to process in parallel
temperature: Temperature for relevance scoring (low=consistent)
"""
self.llm_func = llm_func
self.model_name = model_name
self.batch_size = batch_size
self.temperature = temperature
async def rerank(
self,
query: str,
chunks: List[Dict[str, Any]],
top_k: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
Rerank chunks based on relevance to query
Process:
1. Take top chunks from vector search (e.g., top 50)
2. Score each chunk's relevance using LLM (0-10 scale)
3. Re-order by relevance score
4. Return top_k most relevant chunks
Args:
query: Search query
chunks: List of chunks with 'content' field
top_k: Return only top K results (None = return all, reranked)
Returns:
List of reranked chunks with 'relevance_score' field added
"""
if not chunks:
logger.warning("No chunks to rerank")
return []
if len(chunks) == 1:
logger.debug("Only one chunk, skipping reranking")
            chunks[0]['relevance_score'] = 10.0  # lone chunk gets the top score on the 0-10 scale used below
return chunks
logger.info(f"Reranking {len(chunks)} chunks for query: {query[:50]}...")
try:
# Score all chunks in batches
scored_chunks = await self._score_chunks_batch(query, chunks)
# Sort by relevance score (highest first)
scored_chunks.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)
# Return top_k if specified
if top_k:
scored_chunks = scored_chunks[:top_k]
logger.info(
f"Reranking complete. Top score: {scored_chunks[0].get('relevance_score', 0):.2f}, "
f"Bottom score: {scored_chunks[-1].get('relevance_score', 0):.2f}"
)
return scored_chunks
except Exception as e:
logger.error(f"Error during reranking: {e}", exc_info=True)
# Return original order on error
return chunks[:top_k] if top_k else chunks
async def _score_chunks_batch(
self,
query: str,
chunks: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Score chunks in batches for efficiency
Args:
query: Search query
chunks: List of chunks to score
Returns:
Chunks with relevance_score added
"""
scored_chunks = []
# Process in batches to avoid rate limits
for i in range(0, len(chunks), self.batch_size):
batch = chunks[i:i + self.batch_size]
# Score batch concurrently
tasks = [self._score_chunk(query, chunk) for chunk in batch]
batch_scores = await asyncio.gather(*tasks, return_exceptions=True)
# Collect results
for chunk, score_result in zip(batch, batch_scores):
if isinstance(score_result, Exception):
logger.warning(f"Failed to score chunk: {score_result}")
chunk['relevance_score'] = 0.0
else:
chunk['relevance_score'] = score_result
scored_chunks.append(chunk)
return scored_chunks
async def _score_chunk(
self,
query: str,
chunk: Dict[str, Any]
) -> float:
"""
Score a single chunk's relevance to the query using LLM
Prompt engineering approach:
- Ask LLM to act as a relevance judge
- Provide clear scoring criteria (0-10 scale)
- Extract numeric score from response
Args:
query: Search query
chunk: Chunk dictionary with 'content' field
Returns:
Relevance score (0-10)
"""
content = chunk.get('content', '')
if not content:
return 0.0
# Truncate very long chunks to avoid token limits
max_content_length = 1000
if len(content) > max_content_length:
content = content[:max_content_length] + "..."
# Prompt for relevance scoring
prompt = f"""You are a relevance judge. Score how relevant the following passage is to answering the query.
Query: {query}
Passage:
{content}
Scoring criteria:
10 = Directly answers the query with specific, relevant information
8-9 = Highly relevant, provides useful context
6-7 = Somewhat relevant, contains related information
4-5 = Tangentially related, limited usefulness
2-3 = Barely related, mostly off-topic
0-1 = Completely irrelevant
Respond with ONLY a number from 0-10. No explanation needed."""
try:
# Call LLM for scoring
if self.llm_func:
response = await self.llm_func(
prompt=prompt,
temperature=self.temperature,
                    max_tokens=50  # small budget; the response should be a single numeric score
)
else:
# Fallback: no reranking
return 5.0
# Extract numeric score from response
score = self._extract_score(response)
return score
except Exception as e:
logger.error(f"Error scoring chunk: {e}")
return 5.0 # Default mid-range score on error
def _extract_score(self, response: str) -> float:
"""
Extract numeric score from LLM response
Handles various response formats:
- "8.5"
- "Score: 9"
- "The relevance is 7/10"
- "8"
Args:
response: LLM response text
Returns:
Extracted score (0-10), defaults to 5.0 if parsing fails
"""
try:
# Remove whitespace
response = response.strip()
# Try to find a number (int or float) in the response
# Pattern matches: "8", "8.5", "9/10", "Score: 7", etc.
number_pattern = r'(\d+\.?\d*)'
matches = re.findall(number_pattern, response)
if matches:
# Take the first number found
score = float(matches[0])
# Normalize to 0-10 range
score = max(0.0, min(10.0, score))
return score
else:
logger.warning(f"Could not extract score from response: {response}")
return 5.0
except Exception as e:
logger.error(f"Error extracting score: {e}")
return 5.0
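

# Optional helper: one way to supply a real Gemini-backed ``llm_func`` for
# GeminiReranker. This is a hedged sketch rather than part of the module's
# required API: it assumes the `google-generativeai` package is installed and
# that a GEMINI_API_KEY environment variable is set; adapt the wiring to your
# own deployment.
async def gemini_llm_func(
    prompt: str,
    temperature: float = 0.1,
    max_tokens: int = 50,
    **kwargs
) -> str:
    """Call Gemini and return the raw text response (example wiring only)."""
    import os
    import google.generativeai as genai  # imported lazily so the module works without it

    genai.configure(api_key=os.environ["GEMINI_API_KEY"])
    model = genai.GenerativeModel("models/gemini-2.5-flash")
    response = await model.generate_content_async(
        prompt,
        generation_config=genai.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_tokens,
        ),
    )
    return response.text

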
# Example usage
async def main():
"""Example demonstrating reranking"""
# Mock LLM function for testing
    async def mock_llm(prompt: str, **kwargs) -> str:
        # Simulate scoring by keying on passage-specific phrases embedded in the
        # prompt. (Generic words like "directly" or "somewhat" appear in the
        # scoring criteria of every prompt, so they cannot distinguish chunks.)
        if "stomach bleeding" in prompt.lower():
            return "9"
        elif "nausea and heartburn" in prompt.lower():
            return "8"
        else:
            return "3"
# Create reranker
reranker = GeminiReranker(llm_func=mock_llm)
# Example query and chunks
query = "What are the side effects of aspirin?"
chunks = [
{"content": "Aspirin can cause stomach bleeding in some patients..."},
{"content": "The history of aspirin dates back to ancient times..."},
{"content": "Common side effects include nausea and heartburn..."},
]
# Rerank
reranked = await reranker.rerank(query, chunks, top_k=2)
print("Reranked results:")
for i, chunk in enumerate(reranked, 1):
print(f"{i}. Score: {chunk['relevance_score']:.1f} - {chunk['content'][:50]}...")
if __name__ == "__main__":
asyncio.run(main())