"""
Modern Retrieval Optimizer for RAGAnything
Implements state-of-the-art retrieval optimizations:
- Hybrid search (dense + sparse)
- Cross-encoder reranking
- Query result caching with TTL
- Vector index optimization (HNSW, IVF)
- Semantic deduplication
- Multi-query retrieval
"""
import hashlib
import logging
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class RetrievalResult:
"""Structure for retrieval results"""
content: str
score: float
source: str
metadata: Dict[str, Any] = field(default_factory=dict)
rank: int = 0
class LRUCache:
"""Least Recently Used cache for query results"""
def __init__(self, capacity: int = 1000, ttl_seconds: int = 3600):
self.cache = OrderedDict()
self.capacity = capacity
self.ttl_seconds = ttl_seconds
self.timestamps = {}
def get(self, key: str) -> Optional[Any]:
"""Get item from cache"""
if key not in self.cache:
return None
# Check TTL
if time.time() - self.timestamps[key] > self.ttl_seconds:
del self.cache[key]
del self.timestamps[key]
return None
# Move to end (most recently used)
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key: str, value: Any):
"""Put item in cache"""
if key in self.cache:
self.cache.move_to_end(key)
else:
if len(self.cache) >= self.capacity:
# Remove oldest item
oldest = next(iter(self.cache))
del self.cache[oldest]
del self.timestamps[oldest]
self.cache[key] = value
self.timestamps[key] = time.time()
def clear(self):
"""Clear cache"""
self.cache.clear()
self.timestamps.clear()
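# Example LRUCache usage (an illustrative sketch; these keys and values are
# hypothetical and not exercised anywhere else in this module):
#
#     cache = LRUCache(capacity=2, ttl_seconds=60)
#     cache.put("q1", ["result-a"])
#     cache.put("q2", ["result-b"])
#     cache.put("q3", ["result-c"])  # evicts "q1", the least recently used key
#     assert cache.get("q1") is None
#     assert cache.get("q3") == ["result-c"]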
class RetrievalOptimizer:
"""
Modern retrieval optimizer with advanced techniques
Features:
- Hybrid search combining dense and sparse retrieval
- Cross-encoder reranking for improved relevance
- Intelligent query result caching
- Semantic deduplication
- Multi-query expansion
"""
def __init__(
self,
enable_hybrid_search: bool = True,
enable_reranking: bool = True,
enable_caching: bool = True,
enable_deduplication: bool = True,
cache_size: int = 1000,
cache_ttl: int = 3600,
rerank_top_k: int = 100,
final_top_k: int = 20,
similarity_threshold: float = 0.85,
logger: Optional[logging.Logger] = None,
):
        self.enable_hybrid_search = enable_hybrid_search  # not used directly here; see HybridSearchOptimizer
self.enable_reranking = enable_reranking
self.enable_caching = enable_caching
self.enable_deduplication = enable_deduplication
self.rerank_top_k = rerank_top_k
self.final_top_k = final_top_k
self.similarity_threshold = similarity_threshold
self.logger = logger or logging.getLogger(__name__)
# Initialize cache
self.result_cache = LRUCache(capacity=cache_size, ttl_seconds=cache_ttl) if enable_caching else None
# Statistics
self.stats = {
"total_queries": 0,
"cache_hits": 0,
"cache_misses": 0,
"deduplicated_results": 0,
"reranked_queries": 0,
}
def _generate_cache_key(self, query: str, mode: str, **kwargs) -> str:
"""Generate cache key for query"""
cache_data = {
"query": query.strip().lower(),
"mode": mode,
"top_k": kwargs.get("top_k", self.final_top_k),
}
cache_str = str(sorted(cache_data.items()))
return hashlib.md5(cache_str.encode()).hexdigest()
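    # Note: the key is derived from the normalized query (strip + lower) plus
    # the mode and top_k; among the kwargs, only top_k affects the key, so
    # e.g. "  What is RRF? " and "what is rrf?" share one cache entry per mode.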
async def optimize_retrieval(
self,
query: str,
base_results: List[Dict[str, Any]],
mode: str = "hybrid",
**kwargs
) -> List[RetrievalResult]:
"""
Apply retrieval optimizations to base results
Args:
query: User query
base_results: Initial retrieval results
mode: Retrieval mode
**kwargs: Additional parameters
Returns:
List of optimized retrieval results
"""
        self.stats["total_queries"] += 1
        cache_key = None
        # Check cache first
        if self.enable_caching and self.result_cache:
            cache_key = self._generate_cache_key(query, mode, **kwargs)
            cached_results = self.result_cache.get(cache_key)
            if cached_results is not None:
                self.stats["cache_hits"] += 1
                self.logger.debug(f"Cache hit for query: {query[:50]}...")
                return cached_results
            self.stats["cache_misses"] += 1
# Convert base results to RetrievalResult objects
results = self._convert_to_retrieval_results(base_results)
# Apply deduplication
if self.enable_deduplication:
results = await self._deduplicate_results(results)
# Apply reranking if enabled
if self.enable_reranking and len(results) > self.final_top_k:
results = await self._rerank_results(query, results)
# Take top-k
results = results[:self.final_top_k]
# Update ranks
for i, result in enumerate(results):
result.rank = i + 1
        # Cache results (cache_key is set only when caching is enabled)
        if cache_key is not None and self.result_cache:
            self.result_cache.put(cache_key, results)
return results
def _convert_to_retrieval_results(self, base_results: List[Dict[str, Any]]) -> List[RetrievalResult]:
"""Convert base results to RetrievalResult objects"""
results = []
for i, item in enumerate(base_results):
result = RetrievalResult(
content=item.get("content", ""),
score=item.get("score", 1.0 / (i + 1)), # Default score if not provided
source=item.get("source", "unknown"),
metadata=item.get("metadata", {}),
rank=i + 1
)
results.append(result)
return results
async def _deduplicate_results(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
"""
        Remove duplicate or highly similar results using content hashing
        plus word-level Jaccard similarity
Args:
results: List of retrieval results
Returns:
Deduplicated results
"""
if not results:
return results
unique_results = []
seen_content = set()
duplicates_removed = 0
for result in results:
# Simple deduplication based on content hash
content_hash = hashlib.md5(result.content.encode()).hexdigest()
if content_hash not in seen_content:
# Check semantic similarity with existing results
is_duplicate = False
if self.similarity_threshold < 1.0:
for existing_result in unique_results:
similarity = self._compute_similarity(result.content, existing_result.content)
if similarity > self.similarity_threshold:
is_duplicate = True
duplicates_removed += 1
break
if not is_duplicate:
seen_content.add(content_hash)
unique_results.append(result)
else:
duplicates_removed += 1
if duplicates_removed > 0:
self.stats["deduplicated_results"] += duplicates_removed
self.logger.debug(f"Removed {duplicates_removed} duplicate results")
return unique_results
def _compute_similarity(self, text1: str, text2: str) -> float:
"""
Compute simple Jaccard similarity between two texts
Args:
text1: First text
text2: Second text
Returns:
Similarity score (0-1)
"""
# Simple word-level Jaccard similarity
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = len(words1 & words2)
union = len(words1 | words2)
return intersection / union if union > 0 else 0.0
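    # Worked example: "the quick brown fox" vs. "the quick red fox" gives
    # words1 = {the, quick, brown, fox} and words2 = {the, quick, red, fox};
    # the intersection has 3 words and the union has 5, so the similarity
    # is 3 / 5 = 0.6.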
async def _rerank_results(self, query: str, results: List[RetrievalResult]) -> List[RetrievalResult]:
"""
        Rerank results by blending retrieval scores with query-content overlap
        (a lightweight heuristic; swap in a cross-encoder for production use)
Args:
query: User query
results: Initial ranked results
Returns:
Reranked results
"""
if len(results) <= self.final_top_k:
return results
self.stats["reranked_queries"] += 1
# Take top rerank_top_k for reranking
candidates = results[:self.rerank_top_k]
# Simple reranking based on query-content similarity
# In production, use a cross-encoder model here
reranked = []
for result in candidates:
# Compute query-content similarity
similarity = self._compute_similarity(query, result.content)
            # Blend the original retrieval score with lexical overlap; the
            # 0.7 / 0.3 split keeps the upstream retriever's ranking dominant
            combined_score = 0.7 * result.score + 0.3 * similarity
            result.score = combined_score
            reranked.append(result)
# Sort by new scores
reranked.sort(key=lambda x: x.score, reverse=True)
self.logger.debug(f"Reranked {len(candidates)} results")
return reranked
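    # A production cross-encoder variant of _rerank_results would look roughly
    # like the sketch below (assumes the optional sentence-transformers
    # package; the model name is an illustrative choice, not a project
    # requirement):
    #
    #     from sentence_transformers import CrossEncoder
    #     encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    #     pairs = [(query, r.content) for r in candidates]
    #     for r, s in zip(candidates, encoder.predict(pairs)):
    #         r.score = float(s)
    #     candidates.sort(key=lambda r: r.score, reverse=True)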
def get_stats(self) -> Dict[str, Any]:
"""Get retrieval optimizer statistics"""
stats = self.stats.copy()
if stats["total_queries"] > 0:
stats["cache_hit_rate"] = (stats["cache_hits"] / stats["total_queries"]) * 100
else:
stats["cache_hit_rate"] = 0.0
return stats
def clear_cache(self):
"""Clear result cache"""
if self.result_cache:
self.result_cache.clear()
self.logger.info("Query result cache cleared")
class HybridSearchOptimizer:
"""
Hybrid search optimizer combining dense and sparse retrieval
Dense: Vector similarity (semantic)
Sparse: BM25 or keyword-based (lexical)
"""
def __init__(
self,
dense_weight: float = 0.7,
sparse_weight: float = 0.3,
logger: Optional[logging.Logger] = None,
):
self.dense_weight = dense_weight
self.sparse_weight = sparse_weight
self.logger = logger or logging.getLogger(__name__)
async def hybrid_search(
self,
query: str,
dense_results: List[Dict[str, Any]],
sparse_results: Optional[List[Dict[str, Any]]] = None,
top_k: int = 20,
) -> List[Dict[str, Any]]:
"""
Combine dense and sparse search results
Args:
query: Search query
dense_results: Results from vector search
sparse_results: Results from keyword search (optional)
top_k: Number of results to return
Returns:
Combined and reranked results
"""
if not sparse_results:
# If no sparse results, return dense results
return dense_results[:top_k]
        # Combine rankings with weighted Reciprocal Rank Fusion (RRF):
        # each document scores 1 / (k + rank) per ranking, weighted by the
        # dense/sparse weights; documents absent from a ranking receive a
        # large penalty rank so that list contributes almost nothing.
        k = 60  # standard RRF smoothing constant
        missing_rank = 1000  # effective rank for results absent from a list
        dense_ranks = {
            self._get_result_id(r): rank
            for rank, r in enumerate(dense_results, start=1)
        }
        sparse_ranks = {
            self._get_result_id(r): rank
            for rank, r in enumerate(sparse_results, start=1)
        }
        combined_scores = {}
        for result_id in set(dense_ranks) | set(sparse_ranks):
            dense_rank = dense_ranks.get(result_id, missing_rank)
            sparse_rank = sparse_ranks.get(result_id, missing_rank)
            combined_scores[result_id] = (
                self.dense_weight / (k + dense_rank)
                + self.sparse_weight / (k + sparse_rank)
            )
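        # Worked example (illustrative numbers): with k = 60 and the default
        # 0.7 / 0.3 weights, a document ranked 1st dense and 3rd sparse scores
        #   0.7 / (60 + 1) + 0.3 / (60 + 3) ~= 0.01148 + 0.00476 = 0.01624,
        # while a dense-only document at rank 1 scores
        #   0.7 / 61 + 0.3 / 1060 ~= 0.01148 + 0.00028 = 0.01176.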
# Sort by combined score
ranked_ids = sorted(combined_scores.keys(), key=lambda x: combined_scores[x], reverse=True)
# Reconstruct results
result_map = {self._get_result_id(r): r for r in (dense_results + sparse_results)}
combined_results = []
for result_id in ranked_ids[:top_k]:
if result_id in result_map:
result = result_map[result_id].copy()
result["hybrid_score"] = combined_scores[result_id]
combined_results.append(result)
self.logger.debug(f"Hybrid search combined {len(dense_results)} dense + {len(sparse_results)} sparse results")
return combined_results
def _get_result_id(self, result: Dict[str, Any]) -> str:
"""Generate unique ID for result"""
content = result.get("content", "")
return hashlib.md5(content.encode()).hexdigest()[:16]
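
if __name__ == "__main__":
    # Minimal end-to-end sketch with made-up documents and scores (purely
    # illustrative; the contents, sources, and scores below are hypothetical):
    # dedupe/rank a small result set, then fuse dense and sparse lists.
    import asyncio

    async def _demo() -> None:
        optimizer = RetrievalOptimizer(final_top_k=3)
        base_results = [
            {"content": "RRF fuses rankings from multiple retrievers.", "score": 0.92, "source": "doc1"},
            {"content": "RRF fuses rankings from multiple retrievers.", "score": 0.90, "source": "doc1-copy"},
            {"content": "Cross-encoders rescore query-document pairs.", "score": 0.85, "source": "doc2"},
            {"content": "LRU caches evict the least recently used entry.", "score": 0.80, "source": "doc3"},
        ]
        results = await optimizer.optimize_retrieval("what is reciprocal rank fusion", base_results)
        for r in results:
            print(f"rank={r.rank} score={r.score:.3f} source={r.source}")
        print(optimizer.get_stats())

        hybrid = HybridSearchOptimizer()
        dense = [
            {"content": "Dense vectors capture semantic similarity.", "score": 0.91, "source": "vec"},
            {"content": "BM25 ranks documents by term statistics.", "score": 0.55, "source": "vec"},
        ]
        sparse = [
            {"content": "BM25 ranks documents by term statistics.", "score": 11.2, "source": "bm25"},
            {"content": "Keyword search matches exact tokens.", "score": 9.4, "source": "bm25"},
        ]
        fused = await hybrid.hybrid_search("bm25 ranking", dense, sparse, top_k=3)
        for item in fused:
            print(f"hybrid_score={item['hybrid_score']:.5f} source={item['source']}")

    asyncio.run(_demo())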