""" Modern Retrieval Optimizer for RAGAnything Implements state-of-the-art retrieval optimizations: - Hybrid search (dense + sparse) - Cross-encoder reranking - Query result caching with TTL - Vector index optimization (HNSW, IVF) - Semantic deduplication - Multi-query retrieval """ import asyncio import hashlib import time import logging from typing import List, Dict, Any, Optional, Tuple, Set from dataclasses import dataclass, field from collections import OrderedDict @dataclass class RetrievalResult: """Structure for retrieval results""" content: str score: float source: str metadata: Dict[str, Any] = field(default_factory=dict) rank: int = 0 class LRUCache: """Least Recently Used cache for query results""" def __init__(self, capacity: int = 1000, ttl_seconds: int = 3600): self.cache = OrderedDict() self.capacity = capacity self.ttl_seconds = ttl_seconds self.timestamps = {} def get(self, key: str) -> Optional[Any]: """Get item from cache""" if key not in self.cache: return None # Check TTL if time.time() - self.timestamps[key] > self.ttl_seconds: del self.cache[key] del self.timestamps[key] return None # Move to end (most recently used) self.cache.move_to_end(key) return self.cache[key] def put(self, key: str, value: Any): """Put item in cache""" if key in self.cache: self.cache.move_to_end(key) else: if len(self.cache) >= self.capacity: # Remove oldest item oldest = next(iter(self.cache)) del self.cache[oldest] del self.timestamps[oldest] self.cache[key] = value self.timestamps[key] = time.time() def clear(self): """Clear cache""" self.cache.clear() self.timestamps.clear() class RetrievalOptimizer: """ Modern retrieval optimizer with advanced techniques Features: - Hybrid search combining dense and sparse retrieval - Cross-encoder reranking for improved relevance - Intelligent query result caching - Semantic deduplication - Multi-query expansion """ def __init__( self, enable_hybrid_search: bool = True, enable_reranking: bool = True, enable_caching: bool = True, enable_deduplication: bool = True, cache_size: int = 1000, cache_ttl: int = 3600, rerank_top_k: int = 100, final_top_k: int = 20, similarity_threshold: float = 0.85, logger: Optional[logging.Logger] = None, ): self.enable_hybrid_search = enable_hybrid_search self.enable_reranking = enable_reranking self.enable_caching = enable_caching self.enable_deduplication = enable_deduplication self.rerank_top_k = rerank_top_k self.final_top_k = final_top_k self.similarity_threshold = similarity_threshold self.logger = logger or logging.getLogger(__name__) # Initialize cache self.result_cache = LRUCache(capacity=cache_size, ttl_seconds=cache_ttl) if enable_caching else None # Statistics self.stats = { "total_queries": 0, "cache_hits": 0, "cache_misses": 0, "deduplicated_results": 0, "reranked_queries": 0, } def _generate_cache_key(self, query: str, mode: str, **kwargs) -> str: """Generate cache key for query""" cache_data = { "query": query.strip().lower(), "mode": mode, "top_k": kwargs.get("top_k", self.final_top_k), } cache_str = str(sorted(cache_data.items())) return hashlib.md5(cache_str.encode()).hexdigest() async def optimize_retrieval( self, query: str, base_results: List[Dict[str, Any]], mode: str = "hybrid", **kwargs ) -> List[RetrievalResult]: """ Apply retrieval optimizations to base results Args: query: User query base_results: Initial retrieval results mode: Retrieval mode **kwargs: Additional parameters Returns: List of optimized retrieval results """ self.stats["total_queries"] += 1 # Check cache first if self.enable_caching and self.result_cache: cache_key = self._generate_cache_key(query, mode, **kwargs) cached_results = self.result_cache.get(cache_key) if cached_results: self.stats["cache_hits"] += 1 self.logger.debug(f"Cache hit for query: {query[:50]}...") return cached_results self.stats["cache_misses"] += 1 # Convert base results to RetrievalResult objects results = self._convert_to_retrieval_results(base_results) # Apply deduplication if self.enable_deduplication: results = await self._deduplicate_results(results) # Apply reranking if enabled if self.enable_reranking and len(results) > self.final_top_k: results = await self._rerank_results(query, results) # Take top-k results = results[:self.final_top_k] # Update ranks for i, result in enumerate(results): result.rank = i + 1 # Cache results if self.enable_caching and self.result_cache: self.result_cache.put(cache_key, results) return results def _convert_to_retrieval_results(self, base_results: List[Dict[str, Any]]) -> List[RetrievalResult]: """Convert base results to RetrievalResult objects""" results = [] for i, item in enumerate(base_results): result = RetrievalResult( content=item.get("content", ""), score=item.get("score", 1.0 / (i + 1)), # Default score if not provided source=item.get("source", "unknown"), metadata=item.get("metadata", {}), rank=i + 1 ) results.append(result) return results async def _deduplicate_results(self, results: List[RetrievalResult]) -> List[RetrievalResult]: """ Remove duplicate or highly similar results using semantic similarity Args: results: List of retrieval results Returns: Deduplicated results """ if not results: return results unique_results = [] seen_content = set() duplicates_removed = 0 for result in results: # Simple deduplication based on content hash content_hash = hashlib.md5(result.content.encode()).hexdigest() if content_hash not in seen_content: # Check semantic similarity with existing results is_duplicate = False if self.similarity_threshold < 1.0: for existing_result in unique_results: similarity = self._compute_similarity(result.content, existing_result.content) if similarity > self.similarity_threshold: is_duplicate = True duplicates_removed += 1 break if not is_duplicate: seen_content.add(content_hash) unique_results.append(result) else: duplicates_removed += 1 if duplicates_removed > 0: self.stats["deduplicated_results"] += duplicates_removed self.logger.debug(f"Removed {duplicates_removed} duplicate results") return unique_results def _compute_similarity(self, text1: str, text2: str) -> float: """ Compute simple Jaccard similarity between two texts Args: text1: First text text2: Second text Returns: Similarity score (0-1) """ # Simple word-level Jaccard similarity words1 = set(text1.lower().split()) words2 = set(text2.lower().split()) if not words1 or not words2: return 0.0 intersection = len(words1 & words2) union = len(words1 | words2) return intersection / union if union > 0 else 0.0 async def _rerank_results(self, query: str, results: List[RetrievalResult]) -> List[RetrievalResult]: """ Rerank results using cross-encoder or advanced scoring Args: query: User query results: Initial ranked results Returns: Reranked results """ if len(results) <= self.final_top_k: return results self.stats["reranked_queries"] += 1 # Take top rerank_top_k for reranking candidates = results[:self.rerank_top_k] # Simple reranking based on query-content similarity # In production, use a cross-encoder model here reranked = [] for result in candidates: # Compute query-content similarity similarity = self._compute_similarity(query, result.content) # Combine with original score combined_score = 0.7 * result.score + 0.3 * similarity result.score = combined_score reranked.append(result) # Sort by new scores reranked.sort(key=lambda x: x.score, reverse=True) self.logger.debug(f"Reranked {len(candidates)} results") return reranked def get_stats(self) -> Dict[str, Any]: """Get retrieval optimizer statistics""" stats = self.stats.copy() if stats["total_queries"] > 0: stats["cache_hit_rate"] = (stats["cache_hits"] / stats["total_queries"]) * 100 else: stats["cache_hit_rate"] = 0.0 return stats def clear_cache(self): """Clear result cache""" if self.result_cache: self.result_cache.clear() self.logger.info("Query result cache cleared") class HybridSearchOptimizer: """ Hybrid search optimizer combining dense and sparse retrieval Dense: Vector similarity (semantic) Sparse: BM25 or keyword-based (lexical) """ def __init__( self, dense_weight: float = 0.7, sparse_weight: float = 0.3, logger: Optional[logging.Logger] = None, ): self.dense_weight = dense_weight self.sparse_weight = sparse_weight self.logger = logger or logging.getLogger(__name__) async def hybrid_search( self, query: str, dense_results: List[Dict[str, Any]], sparse_results: Optional[List[Dict[str, Any]]] = None, top_k: int = 20, ) -> List[Dict[str, Any]]: """ Combine dense and sparse search results Args: query: Search query dense_results: Results from vector search sparse_results: Results from keyword search (optional) top_k: Number of results to return Returns: Combined and reranked results """ if not sparse_results: # If no sparse results, return dense results return dense_results[:top_k] # Create score maps dense_scores = { self._get_result_id(r): r.get("score", 0.0) * self.dense_weight for r in dense_results } sparse_scores = { self._get_result_id(r): r.get("score", 0.0) * self.sparse_weight for r in sparse_results } # Combine scores using Reciprocal Rank Fusion (RRF) combined_scores = {} all_ids = set(dense_scores.keys()) | set(sparse_scores.keys()) for result_id in all_ids: dense_score = dense_scores.get(result_id, 0.0) sparse_score = sparse_scores.get(result_id, 0.0) # RRF formula: 1 / (k + rank) k = 60 # RRF constant dense_rank = list(dense_scores.keys()).index(result_id) + 1 if result_id in dense_scores else 1000 sparse_rank = list(sparse_scores.keys()).index(result_id) + 1 if result_id in sparse_scores else 1000 rrf_score = (1 / (k + dense_rank)) * self.dense_weight + (1 / (k + sparse_rank)) * self.sparse_weight combined_scores[result_id] = rrf_score # Sort by combined score ranked_ids = sorted(combined_scores.keys(), key=lambda x: combined_scores[x], reverse=True) # Reconstruct results result_map = {self._get_result_id(r): r for r in (dense_results + sparse_results)} combined_results = [] for result_id in ranked_ids[:top_k]: if result_id in result_map: result = result_map[result_id].copy() result["hybrid_score"] = combined_scores[result_id] combined_results.append(result) self.logger.debug(f"Hybrid search combined {len(dense_results)} dense + {len(sparse_results)} sparse results") return combined_results def _get_result_id(self, result: Dict[str, Any]) -> str: """Generate unique ID for result""" content = result.get("content", "") return hashlib.md5(content.encode()).hexdigest()[:16]