"""
Reranking Module for RAG-Anything

Provides reranking functionality using:
1. Gemini-based LLM reranking (free-tier compatible)
2. Cross-encoder-style scoring
3. Relevance-based reordering

Reranking is crucial for RAG systems because:
- Vector search (embeddings) finds semantically similar text but may miss subtle context
- LLMs can deeply understand query intent and document relevance
- Reranking improves answer quality by promoting truly relevant chunks to the top

Author: RAG-Anything Team
Version: 1.0.0
"""
import asyncio
import logging
import re
from typing import Any, Callable, Dict, List, Optional

logger = logging.getLogger(__name__)


class GeminiReranker:
"""
Reranker using Gemini API for semantic relevance scoring
This reranker takes chunks from vector search and re-scores them
based on deep semantic understanding using an LLM.
Why reranking matters:
---------------------
Vector embeddings alone can miss:
- Negations ("not effective" vs "effective")
- Context dependencies ("aspirin for elderly" vs "aspirin for children")
- Query intent ("what causes X" vs "how to prevent X")
LLM reranking provides:
- Contextual understanding of the query
- Semantic relevance beyond keyword matching
- Better handling of complex queries
"""

    def __init__(
        self,
        llm_func: Optional[Callable] = None,
        model_name: str = "models/gemini-2.5-flash",
        batch_size: int = 5,
        temperature: float = 0.1,
    ):
        """
        Initialize the Gemini-based reranker.

        Args:
            llm_func: Optional async LLM function to use for reranking
            model_name: Gemini model to use (default: a flash model, for speed)
            batch_size: Number of chunks to score concurrently per batch
            temperature: Temperature for relevance scoring (low = consistent)
        """
        self.llm_func = llm_func
        self.model_name = model_name
        self.batch_size = batch_size
        self.temperature = temperature
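
    # Expected llm_func contract (assumed by this module, not enforced): an
    # async callable accepting 'prompt', 'temperature', and 'max_tokens'
    # keyword arguments and returning the model's raw text response, e.g.:
    #   async def my_llm(prompt: str, temperature: float, max_tokens: int) -> str: ...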

    async def rerank(
        self,
        query: str,
        chunks: List[Dict[str, Any]],
        top_k: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """
        Rerank chunks based on relevance to the query.

        Process:
        1. Take the top chunks from vector search (e.g., top 50)
        2. Score each chunk's relevance using the LLM (0-10 scale)
        3. Reorder by relevance score
        4. Return the top_k most relevant chunks

        Args:
            query: Search query
            chunks: List of chunks, each with a 'content' field
            top_k: Return only the top K results (None = return all, reranked)

        Returns:
            List of reranked chunks with a 'relevance_score' field added
        """
        if not chunks:
            logger.warning("No chunks to rerank")
            return []

        if len(chunks) == 1:
            logger.debug("Only one chunk, skipping reranking")
            chunks[0]['relevance_score'] = 10.0  # top of the 0-10 scale used below
            return chunks

        logger.info(f"Reranking {len(chunks)} chunks for query: {query[:50]}...")

        try:
            # Score all chunks in batches
            scored_chunks = await self._score_chunks_batch(query, chunks)

            # Sort by relevance score (highest first)
            scored_chunks.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)

            # Return only top_k if specified
            if top_k:
                scored_chunks = scored_chunks[:top_k]

            logger.info(
                f"Reranking complete. Top score: {scored_chunks[0].get('relevance_score', 0):.2f}, "
                f"Bottom score: {scored_chunks[-1].get('relevance_score', 0):.2f}"
            )
            return scored_chunks

        except Exception as e:
            logger.error(f"Error during reranking: {e}", exc_info=True)
            # Fall back to the original order on error
            return chunks[:top_k] if top_k else chunks
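
    # Illustrative call, assuming a wired-up llm_func (names are examples only):
    #   reranked = await reranker.rerank("What causes X?", chunks, top_k=10)
    #   best = reranked[0]  # highest 'relevance_score' first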

    async def _score_chunks_batch(
        self,
        query: str,
        chunks: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Score chunks in batches for efficiency.

        Args:
            query: Search query
            chunks: List of chunks to score

        Returns:
            Chunks with a 'relevance_score' field added
        """
        scored_chunks = []

        # Process in batches to avoid rate limits
        for i in range(0, len(chunks), self.batch_size):
            batch = chunks[i:i + self.batch_size]

            # Score the batch concurrently
            tasks = [self._score_chunk(query, chunk) for chunk in batch]
            batch_scores = await asyncio.gather(*tasks, return_exceptions=True)

            # Collect results, falling back to 0.0 for chunks that failed
            for chunk, score_result in zip(batch, batch_scores):
                if isinstance(score_result, Exception):
                    logger.warning(f"Failed to score chunk: {score_result}")
                    chunk['relevance_score'] = 0.0
                else:
                    chunk['relevance_score'] = score_result
                scored_chunks.append(chunk)

        return scored_chunks
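
    # Note: batch_size trades throughput against rate limits. The exact Gemini
    # free-tier quota is an assumption here (limits vary by account), but
    # smaller batches keep the number of concurrent requests per window low.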

    async def _score_chunk(
        self,
        query: str,
        chunk: Dict[str, Any],
    ) -> float:
        """
        Score a single chunk's relevance to the query using the LLM.

        Prompt engineering approach:
        - Ask the LLM to act as a relevance judge
        - Provide clear scoring criteria (0-10 scale)
        - Extract the numeric score from the response

        Args:
            query: Search query
            chunk: Chunk dictionary with a 'content' field

        Returns:
            Relevance score (0-10)
        """
        content = chunk.get('content', '')
        if not content:
            return 0.0

        # Truncate very long chunks to avoid token limits
        max_content_length = 1000
        if len(content) > max_content_length:
            content = content[:max_content_length] + "..."

        # Prompt for relevance scoring
        prompt = f"""You are a relevance judge. Score how relevant the following passage is to answering the query.

Query: {query}

Passage:
{content}

Scoring criteria:
10 = Directly answers the query with specific, relevant information
8-9 = Highly relevant, provides useful context
6-7 = Somewhat relevant, contains related information
4-5 = Tangentially related, limited usefulness
2-3 = Barely related, mostly off-topic
0-1 = Completely irrelevant

Respond with ONLY a number from 0-10. No explanation needed."""

        try:
            # Call the LLM for scoring
            if self.llm_func:
                response = await self.llm_func(
                    prompt=prompt,
                    temperature=self.temperature,
                    max_tokens=50,  # enough room for a complete numeric response
                )
            else:
                # No LLM available: return a neutral score (no reranking effect)
                return 5.0

            # Extract the numeric score from the response
            return self._extract_score(response)

        except Exception as e:
            logger.error(f"Error scoring chunk: {e}")
            return 5.0  # Default mid-range score on error

    def _extract_score(self, response: str) -> float:
        """
        Extract a numeric score from the LLM response.

        Handles various response formats:
        - "8.5"
        - "Score: 9"
        - "The relevance is 7/10"
        - "8"

        Args:
            response: LLM response text

        Returns:
            Extracted score (0-10); defaults to 5.0 if parsing fails
        """
        try:
            response = response.strip()

            # Find the first number (int or float) in the response.
            # Matches: "8", "8.5", "9/10", "Score: 7", etc.
            number_pattern = r'(\d+\.?\d*)'
            matches = re.findall(number_pattern, response)

            if matches:
                # Take the first number found and clamp to the 0-10 range
                score = float(matches[0])
                return max(0.0, min(10.0, score))
            else:
                logger.warning(f"Could not extract score from response: {response}")
                return 5.0

        except Exception as e:
            logger.error(f"Error extracting score: {e}")
            return 5.0
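
    # Illustrative parses (what _extract_score returns for common responses;
    # examples only, not executed):
    #   "8.5"                   -> 8.5
    #   "Score: 9"              -> 9.0
    #   "The relevance is 7/10" -> 7.0  (the first number wins)
    #   "no idea"               -> 5.0  (fallback default)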


# Example usage
async def main():
    """Example demonstrating reranking"""

    # Mock LLM for testing: scores based on passage-specific keywords. (Keying
    # on rubric words like "directly" or "somewhat" would not work, since the
    # scoring criteria appear verbatim in every prompt.)
    async def mock_llm(prompt: str, **kwargs) -> str:
        prompt_lower = prompt.lower()
        if "bleeding" in prompt_lower or "nausea" in prompt_lower:
            return "9"  # passages mentioning concrete side effects
        elif "history" in prompt_lower:
            return "2"  # off-topic background passage
        else:
            return "5"

    # Create the reranker
    reranker = GeminiReranker(llm_func=mock_llm)

    # Example query and chunks
    query = "What are the side effects of aspirin?"
    chunks = [
        {"content": "Aspirin can cause stomach bleeding in some patients..."},
        {"content": "The history of aspirin dates back to ancient times..."},
        {"content": "Common side effects include nausea and heartburn..."},
    ]

    # Rerank
    reranked = await reranker.rerank(query, chunks, top_k=2)

    print("Reranked results:")
    for i, chunk in enumerate(reranked, 1):
        print(f"{i}. Score: {chunk['relevance_score']:.1f} - {chunk['content'][:50]}...")


if __name__ == "__main__":
    asyncio.run(main())
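
# A minimal sketch of wiring GeminiReranker to a real Gemini call through the
# google-generativeai SDK. This is illustrative glue code, not part of the
# module: the API-key handling and configuration values below are assumptions
# to adapt to your setup.
#
#   import google.generativeai as genai
#
#   genai.configure(api_key="YOUR_API_KEY")
#
#   async def gemini_llm(prompt: str, temperature: float = 0.1,
#                        max_tokens: int = 50, **kwargs) -> str:
#       model = genai.GenerativeModel("models/gemini-2.5-flash")
#       response = await model.generate_content_async(
#           prompt,
#           generation_config=genai.types.GenerationConfig(
#               temperature=temperature,
#               max_output_tokens=max_tokens,
#           ),
#       )
#       return response.text
#
#   reranker = GeminiReranker(llm_func=gemini_llm)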