| """ |
| Reranker Agent |
| |
| Cross-encoder based reranking for improved retrieval precision. |
| Follows FAANG best practices for production RAG systems. |
| |
| Key Features: |
| - LLM-based cross-encoder reranking |
| - Relevance scoring with explanations |
| - Diversity promotion to avoid redundancy |
| - Quality filtering (removes low-quality chunks) |
| - Chunk deduplication |
| """ |
|
|
| from typing import List, Optional, Dict, Any, Tuple |
| from pydantic import BaseModel, Field |
| from loguru import logger |
| from dataclasses import dataclass |
| import json |
| import re |
| from difflib import SequenceMatcher |
|
|
| try: |
| import httpx |
| HTTPX_AVAILABLE = True |
| except ImportError: |
| HTTPX_AVAILABLE = False |
|
|
| from .retriever import RetrievalResult |
|
|
|
|
class RerankerConfig(BaseModel):
    """Configuration for reranking."""

    # LLM backend used for cross-encoder scoring (Ollama-compatible endpoint).
    model: str = Field("llama3.2:3b")
    base_url: str = Field("http://localhost:11434")
    temperature: float = Field(0.1)

    # Output size and the minimum relevance a chunk must reach to survive filtering.
    top_k: int = Field(5, ge=1)
    min_relevance_score: float = Field(0.3, ge=0.0, le=1.0)

    # Diversity promotion (MMR-style selection between similar chunks).
    enable_diversity: bool = Field(True)
    diversity_threshold: float = Field(0.8, description="Max similarity between chunks")

    # Near-duplicate removal threshold applied before scoring.
    dedup_threshold: float = Field(0.9, description="Similarity threshold for dedup")

    # When False (or httpx is unavailable), a fast lexical heuristic is used instead.
    use_llm_rerank: bool = Field(True)
|
|
|
|
class RankedResult(BaseModel):
    """A reranked result with relevance score."""

    # Identity and content of the chunk.
    chunk_id: str
    document_id: str
    text: str

    # Scores: raw retrieval score, normalized reranker relevance, and the
    # blended final score used for ordering.
    original_score: float
    relevance_score: float
    final_score: float
    relevance_explanation: Optional[str] = None

    # Provenance metadata carried through unchanged from retrieval.
    page: Optional[int] = None
    chunk_type: Optional[str] = None
    source_path: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    bbox: Optional[Dict[str, float]] = None
|
|
|
|
class RerankerAgent:
    """
    Reranks retrieval results for improved precision.

    Capabilities:
    1. Cross-encoder relevance scoring
    2. Diversity-aware reranking (MMR-style)
    3. Quality filtering
    4. Chunk deduplication
    """

    RERANK_PROMPT = """Score the relevance of this text passage to the given query.

Query: {query}

Passage: {passage}

Score the relevance on a scale of 0-10 where:
- 0-2: Completely irrelevant, no useful information
- 3-4: Marginally relevant, tangentially related
- 5-6: Somewhat relevant, contains some useful information
- 7-8: Highly relevant, directly addresses the query
- 9-10: Perfectly relevant, comprehensive answer to query

Respond with ONLY a JSON object:
{{"score": <number>, "explanation": "<brief reason>"}}"""

    def __init__(self, config: Optional[RerankerConfig] = None):
        """
        Initialize Reranker Agent.

        Args:
            config: Reranker configuration (defaults are used when omitted).
        """
        self.config = config or RerankerConfig()
        logger.info(f"RerankerAgent initialized (model={self.config.model})")

    def rerank(
        self,
        query: str,
        results: List[RetrievalResult],
        top_k: Optional[int] = None,
    ) -> List[RankedResult]:
        """
        Rerank retrieval results by relevance to query.

        Pipeline: deduplicate -> score (LLM or heuristic) -> quality filter
        -> diversity-aware selection.

        Args:
            query: Original search query
            results: Retrieval results to rerank
            top_k: Number of results to return (config default when None)

        Returns:
            Reranked results with relevance scores
        """
        if not results:
            return []

        # FIX: explicit None check so an explicit top_k=0 is honored rather
        # than silently falling back to the configured default ('or' treated
        # 0 as "not provided").
        if top_k is None:
            top_k = self.config.top_k

        # Drop near-duplicates first so we don't spend LLM calls on them.
        deduped = self._deduplicate(results)

        # Relevance scoring: LLM cross-encoder when available, else heuristic.
        if self.config.use_llm_rerank and HTTPX_AVAILABLE:
            scored = self._llm_rerank(query, deduped)
        else:
            scored = self._heuristic_rerank(query, deduped)

        # Quality gate: discard chunks below the configured relevance floor.
        filtered = [
            r for r in scored
            if r.relevance_score >= self.config.min_relevance_score
        ]

        # Final ordering: MMR-style diverse selection or a plain score sort.
        if self.config.enable_diversity:
            return self._promote_diversity(filtered, top_k)
        return sorted(filtered, key=lambda x: x.final_score, reverse=True)[:top_k]

    def _deduplicate(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
        """Remove near-duplicate chunks, keeping the first occurrence.

        O(n^2) pairwise comparison; acceptable for typical retrieval sizes.
        """
        if not results:
            return []

        deduped = [results[0]]

        for result in results[1:]:
            is_dup = False
            for existing in deduped:
                similarity = self._text_similarity(result.text, existing.text)
                if similarity > self.config.dedup_threshold:
                    is_dup = True
                    break

            if not is_dup:
                deduped.append(result)

        if len(results) != len(deduped):
            logger.debug(f"Deduplication: {len(results)} -> {len(deduped)} chunks")

        return deduped

    def _text_similarity(self, text1: str, text2: str) -> float:
        """Compute case-insensitive text similarity via SequenceMatcher (0-1)."""
        return SequenceMatcher(None, text1.lower(), text2.lower()).ratio()

    def _llm_rerank(
        self,
        query: str,
        results: List[RetrievalResult],
    ) -> List[RankedResult]:
        """Use LLM for cross-encoder style reranking.

        A single HTTP client is shared across all passages so TCP connections
        are reused instead of re-established per chunk.
        """
        ranked: List[RankedResult] = []

        # FIX: one client for the whole batch (was one per passage).
        with httpx.Client(timeout=30.0) as client:
            for result in results:
                try:
                    relevance_score, explanation = self._score_passage(
                        query, result.text, client=client
                    )

                    # Blend: 30% original retrieval score, 70% LLM relevance
                    # (normalized from the 0-10 prompt scale to 0-1).
                    final_score = 0.3 * result.score + 0.7 * (relevance_score / 10.0)

                    ranked.append(self._to_ranked(
                        result,
                        relevance_score=relevance_score / 10.0,
                        final_score=final_score,
                        explanation=explanation,
                    ))

                except Exception as e:
                    # Best-effort: keep the chunk with its retrieval score
                    # rather than dropping it on a transient LLM failure.
                    logger.warning(f"Failed to score passage: {e}")
                    # FIX: clamp the fallback relevance into [0, 1] so
                    # retrieval scores on other scales (distances, logits)
                    # cannot bypass or be unfairly killed by the
                    # min_relevance_score gate.
                    fallback_relevance = min(max(result.score, 0.0), 1.0)
                    ranked.append(self._to_ranked(
                        result,
                        relevance_score=fallback_relevance,
                        final_score=result.score,
                        explanation=None,
                    ))

        return ranked

    def _to_ranked(
        self,
        result: RetrievalResult,
        *,
        relevance_score: float,
        final_score: float,
        explanation: Optional[str],
    ) -> RankedResult:
        """Copy retrieval metadata into a RankedResult with the given scores."""
        return RankedResult(
            chunk_id=result.chunk_id,
            document_id=result.document_id,
            text=result.text,
            original_score=result.score,
            relevance_score=relevance_score,
            final_score=final_score,
            relevance_explanation=explanation,
            page=result.page,
            chunk_type=result.chunk_type,
            source_path=result.source_path,
            metadata=result.metadata,
            bbox=result.bbox,
        )

    def _score_passage(
        self,
        query: str,
        passage: str,
        client: Optional["httpx.Client"] = None,
    ) -> Tuple[float, str]:
        """Score a single passage (0-10) using the LLM.

        Args:
            query: Search query.
            passage: Passage text (truncated to 1000 chars to bound latency).
            client: Optional shared HTTP client; a short-lived one is created
                when omitted, preserving the previous standalone behavior.

        Returns:
            (score, explanation) tuple.
        """
        prompt = self.RERANK_PROMPT.format(
            query=query,
            passage=passage[:1000],
        )

        if client is None:
            with httpx.Client(timeout=30.0) as owned_client:
                return self._request_score(owned_client, prompt)
        return self._request_score(client, prompt)

    def _request_score(self, client: "httpx.Client", prompt: str) -> Tuple[float, str]:
        """POST the scoring prompt to the Ollama generate endpoint and parse it."""
        response = client.post(
            f"{self.config.base_url}/api/generate",
            json={
                "model": self.config.model,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": self.config.temperature,
                    "num_predict": 256,
                },
            },
        )
        response.raise_for_status()
        result = response.json()

        response_text = result.get("response", "")
        return self._parse_score_response(response_text)

    def _parse_score_response(self, text: str) -> Tuple[float, str]:
        """Parse score and explanation from LLM response.

        Tries the requested JSON object first, then falls back to the first
        number in the text, then to a neutral default of 5.0.
        """
        try:
            # Preferred path: the JSON object the prompt asked for.
            json_match = re.search(r'\{[\s\S]*\}', text)
            if json_match:
                data = json.loads(json_match.group())
                score = float(data.get("score", 5))
                explanation = data.get("explanation", "")
                return min(max(score, 0.0), 10.0), explanation
        except Exception:
            pass

        # FIX: decimal-aware fallback — the old \b([0-9]|10)\b pattern
        # truncated "7.5" to 7 and could not clamp out-of-range values.
        num_match = re.search(r'\b(\d+(?:\.\d+)?)\b', text)
        if num_match:
            return min(max(float(num_match.group(1)), 0.0), 10.0), ""

        # Neutral mid-scale default when nothing parseable is found.
        return 5.0, "Could not parse score"

    def _heuristic_rerank(
        self,
        query: str,
        results: List[RetrievalResult],
    ) -> List[RankedResult]:
        """Fast lexical reranking used when the LLM path is unavailable."""
        query_terms = set(query.lower().split())
        ranked: List[RankedResult] = []

        for result in results:
            text_lower = result.text.lower()

            # Term overlap: fraction of query terms present in the passage.
            text_terms = set(text_lower.split())
            overlap = len(query_terms & text_terms) / len(query_terms) if query_terms else 0

            # Bonus when the full query appears verbatim.
            phrase_bonus = 0.2 if query.lower() in text_lower else 0

            # Mild preference for substantial chunks; saturates at 500 chars.
            length = len(result.text)
            length_score = min(length, 500) / 500

            relevance = 0.5 * overlap + 0.3 * phrase_bonus + 0.2 * length_score
            final_score = 0.4 * result.score + 0.6 * relevance

            ranked.append(self._to_ranked(
                result,
                relevance_score=relevance,
                final_score=final_score,
                explanation=None,
            ))

        return ranked

    def _promote_diversity(
        self,
        results: List[RankedResult],
        top_k: int,
    ) -> List[RankedResult]:
        """
        Promote diversity using MMR-style selection.

        Maximal Marginal Relevance balances relevance with diversity: greedily
        pick the candidate with the best (relevance - similarity-to-selected)
        trade-off until top_k items are chosen.
        """
        if not results:
            return []

        # Seed with the single highest-scoring result.
        sorted_results = sorted(results, key=lambda x: x.final_score, reverse=True)
        selected = [sorted_results[0]]
        remaining = sorted_results[1:]

        while len(selected) < top_k and remaining:
            # FIX: -inf sentinel — final scores can be negative on the LLM
            # fallback path, so an MMR below -1 was previously unreachable
            # and selection silently defaulted to index 0.
            best_mmr = float("-inf")
            best_idx = 0

            for i, candidate in enumerate(remaining):
                relevance = candidate.final_score

                # Penalize by the closest match among already-selected chunks.
                max_sim = max(
                    self._text_similarity(candidate.text, s.text)
                    for s in selected
                )

                # 70% relevance, 30% redundancy penalty.
                mmr = 0.7 * relevance - 0.3 * max_sim

                if mmr > best_mmr:
                    best_mmr = mmr
                    best_idx = i

            selected.append(remaining.pop(best_idx))

        return selected
|
|