| """ |
| Modern Retrieval Optimizer for RAGAnything |
| |
| Implements state-of-the-art retrieval optimizations: |
| - Hybrid search (dense + sparse) |
| - Cross-encoder reranking |
| - Query result caching with TTL |
| - Vector index optimization (HNSW, IVF) |
| - Semantic deduplication |
| - Multi-query retrieval |
| """ |

import asyncio
import hashlib
import time
import logging
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from collections import OrderedDict


@dataclass
class RetrievalResult:
    """Structure for retrieval results"""

    content: str
    score: float
    source: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    rank: int = 0


class LRUCache:
    """Least Recently Used cache for query results"""

    def __init__(self, capacity: int = 1000, ttl_seconds: int = 3600):
        self.cache = OrderedDict()
        self.capacity = capacity
        self.ttl_seconds = ttl_seconds
        self.timestamps = {}

    def get(self, key: str) -> Optional[Any]:
        """Get item from cache, or None if missing or expired"""
        if key not in self.cache:
            return None

        # Evict the entry if its TTL has elapsed
        if time.time() - self.timestamps[key] > self.ttl_seconds:
            del self.cache[key]
            del self.timestamps[key]
            return None

        # Mark as most recently used
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key: str, value: Any):
        """Put item in cache, evicting the least recently used entry if full"""
        if key in self.cache:
            self.cache.move_to_end(key)
        else:
            if len(self.cache) >= self.capacity:
                # OrderedDict preserves insertion/recency order, so the
                # first key is the least recently used entry
                oldest = next(iter(self.cache))
                del self.cache[oldest]
                del self.timestamps[oldest]

        self.cache[key] = value
        self.timestamps[key] = time.time()

    def clear(self):
        """Clear cache"""
        self.cache.clear()
        self.timestamps.clear()
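

# Illustrative usage of LRUCache (a minimal sketch, not part of the optimizer
# API; the keys and values below are made up). Entries expire after
# ttl_seconds, and the least recently used key is evicted once capacity is hit.
def _demo_lru_cache() -> None:
    cache = LRUCache(capacity=2, ttl_seconds=3600)
    cache.put("q1", ["result-a"])
    cache.put("q2", ["result-b"])
    cache.get("q1")  # touches q1, so q2 becomes the least recently used entry
    cache.put("q3", ["result-c"])  # over capacity: q2 is evicted
    assert cache.get("q2") is None
    assert cache.get("q1") == ["result-a"]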


class RetrievalOptimizer:
    """
    Modern retrieval optimizer with advanced techniques

    Features:
    - Hybrid search combining dense and sparse retrieval
    - Lightweight reranking for improved relevance
    - Intelligent query result caching (LRU with TTL)
    - Semantic deduplication (Jaccard similarity)
    """

    def __init__(
        self,
        enable_hybrid_search: bool = True,
        enable_reranking: bool = True,
        enable_caching: bool = True,
        enable_deduplication: bool = True,
        cache_size: int = 1000,
        cache_ttl: int = 3600,
        rerank_top_k: int = 100,
        final_top_k: int = 20,
        similarity_threshold: float = 0.85,
        logger: Optional[logging.Logger] = None,
    ):
        self.enable_hybrid_search = enable_hybrid_search
        self.enable_reranking = enable_reranking
        self.enable_caching = enable_caching
        self.enable_deduplication = enable_deduplication
        self.rerank_top_k = rerank_top_k
        self.final_top_k = final_top_k
        self.similarity_threshold = similarity_threshold
        self.logger = logger or logging.getLogger(__name__)

        # Query result cache (only allocated when caching is enabled)
        self.result_cache = (
            LRUCache(capacity=cache_size, ttl_seconds=cache_ttl)
            if enable_caching
            else None
        )

        # Running statistics, exposed via get_stats()
        self.stats = {
            "total_queries": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "deduplicated_results": 0,
            "reranked_queries": 0,
        }

    def _generate_cache_key(self, query: str, mode: str, **kwargs) -> str:
        """Generate a deterministic cache key for a query"""
        cache_data = {
            "query": query.strip().lower(),
            "mode": mode,
            "top_k": kwargs.get("top_k", self.final_top_k),
        }
        cache_str = str(sorted(cache_data.items()))
        # md5 is used purely as a fast fingerprint, not for security
        return hashlib.md5(cache_str.encode()).hexdigest()

    async def optimize_retrieval(
        self,
        query: str,
        base_results: List[Dict[str, Any]],
        mode: str = "hybrid",
        **kwargs,
    ) -> List[RetrievalResult]:
        """
        Apply retrieval optimizations to base results

        Args:
            query: User query
            base_results: Initial retrieval results
            mode: Retrieval mode
            **kwargs: Additional parameters

        Returns:
            List of optimized retrieval results
        """
        self.stats["total_queries"] += 1

        # Serve from cache when possible
        cache_key: Optional[str] = None
        if self.enable_caching and self.result_cache:
            cache_key = self._generate_cache_key(query, mode, **kwargs)
            cached_results = self.result_cache.get(cache_key)
            if cached_results:
                self.stats["cache_hits"] += 1
                self.logger.debug(f"Cache hit for query: {query[:50]}...")
                return cached_results
            self.stats["cache_misses"] += 1

        # Normalize raw dicts into RetrievalResult objects
        results = self._convert_to_retrieval_results(base_results)

        # Drop exact and near-duplicate results
        if self.enable_deduplication:
            results = await self._deduplicate_results(results)

        # Rerank only when there are more candidates than final slots
        if self.enable_reranking and len(results) > self.final_top_k:
            results = await self._rerank_results(query, results)

        # Truncate to the final result count
        results = results[: self.final_top_k]

        # Assign final 1-based ranks
        for i, result in enumerate(results):
            result.rank = i + 1

        # Store in cache for subsequent identical queries
        if cache_key is not None and self.result_cache:
            self.result_cache.put(cache_key, results)

        return results

    def _convert_to_retrieval_results(
        self, base_results: List[Dict[str, Any]]
    ) -> List[RetrievalResult]:
        """Convert base results to RetrievalResult objects"""
        results = []
        for i, item in enumerate(base_results):
            result = RetrievalResult(
                content=item.get("content", ""),
                # Fall back to a reciprocal-rank score when none is provided
                score=item.get("score", 1.0 / (i + 1)),
                source=item.get("source", "unknown"),
                metadata=item.get("metadata", {}),
                rank=i + 1,
            )
            results.append(result)
        return results

    async def _deduplicate_results(
        self, results: List[RetrievalResult]
    ) -> List[RetrievalResult]:
        """
        Remove duplicate or highly similar results using semantic similarity

        Args:
            results: List of retrieval results

        Returns:
            Deduplicated results
        """
        if not results:
            return results

        unique_results = []
        seen_content = set()
        duplicates_removed = 0

        for result in results:
            # Cheap exact-duplicate check via content hash
            content_hash = hashlib.md5(result.content.encode()).hexdigest()

            if content_hash not in seen_content:
                # Near-duplicate check against the results kept so far
                # (pairwise, so O(n^2) in the worst case)
                is_duplicate = False

                if self.similarity_threshold < 1.0:
                    for existing_result in unique_results:
                        similarity = self._compute_similarity(
                            result.content, existing_result.content
                        )
                        if similarity > self.similarity_threshold:
                            is_duplicate = True
                            duplicates_removed += 1
                            break

                if not is_duplicate:
                    seen_content.add(content_hash)
                    unique_results.append(result)
            else:
                duplicates_removed += 1

        if duplicates_removed > 0:
            self.stats["deduplicated_results"] += duplicates_removed
            self.logger.debug(f"Removed {duplicates_removed} duplicate results")

        return unique_results

    def _compute_similarity(self, text1: str, text2: str) -> float:
        """
        Compute simple Jaccard similarity between two texts

        Args:
            text1: First text
            text2: Second text

        Returns:
            Similarity score (0-1)
        """
        # Token-set Jaccard: |A ∩ B| / |A ∪ B| over lowercased words
        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())

        if not words1 or not words2:
            return 0.0

        intersection = len(words1 & words2)
        union = len(words1 | words2)

        return intersection / union if union > 0 else 0.0
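
    # Worked example (illustrative): "deep learning models" vs. "deep learning
    # systems" share the tokens {"deep", "learning"}; intersection = 2 and the
    # union has 4 distinct tokens, so the similarity is 2 / 4 = 0.5.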

    async def _rerank_results(
        self, query: str, results: List[RetrievalResult]
    ) -> List[RetrievalResult]:
        """
        Rerank results by blending retrieval scores with lexical overlap

        This is a lightweight heuristic; a cross-encoder model could be
        substituted here for stronger relevance estimates.

        Args:
            query: User query
            results: Initial ranked results

        Returns:
            Reranked results
        """
        if len(results) <= self.final_top_k:
            return results

        self.stats["reranked_queries"] += 1

        # Only rescore the top candidates to bound the cost
        candidates = results[: self.rerank_top_k]

        reranked = []
        for result in candidates:
            # Lexical overlap between the query and the candidate content
            similarity = self._compute_similarity(query, result.content)

            # Blend: 70% original retrieval score, 30% query overlap
            combined_score = 0.7 * result.score + 0.3 * similarity

            result.score = combined_score
            reranked.append(result)

        # Highest combined score first
        reranked.sort(key=lambda x: x.score, reverse=True)

        self.logger.debug(f"Reranked {len(candidates)} results")

        return reranked

    def get_stats(self) -> Dict[str, Any]:
        """Get retrieval optimizer statistics"""
        stats = self.stats.copy()

        if stats["total_queries"] > 0:
            stats["cache_hit_rate"] = (
                stats["cache_hits"] / stats["total_queries"]
            ) * 100
        else:
            stats["cache_hit_rate"] = 0.0

        return stats

    def clear_cache(self):
        """Clear result cache"""
        if self.result_cache:
            self.result_cache.clear()
            self.logger.info("Query result cache cleared")


class HybridSearchOptimizer:
    """
    Hybrid search optimizer combining dense and sparse retrieval

    Dense: vector similarity (semantic)
    Sparse: BM25 or keyword-based (lexical)
    """

    def __init__(
        self,
        dense_weight: float = 0.7,
        sparse_weight: float = 0.3,
        logger: Optional[logging.Logger] = None,
    ):
        self.dense_weight = dense_weight
        self.sparse_weight = sparse_weight
        self.logger = logger or logging.getLogger(__name__)

    async def hybrid_search(
        self,
        query: str,
        dense_results: List[Dict[str, Any]],
        sparse_results: Optional[List[Dict[str, Any]]] = None,
        top_k: int = 20,
    ) -> List[Dict[str, Any]]:
        """
        Combine dense and sparse results via weighted reciprocal rank fusion

        Args:
            query: Search query
            dense_results: Results from vector search, best first
            sparse_results: Results from keyword search, best first (optional)
            top_k: Number of results to return

        Returns:
            Combined and reranked results
        """
        if not sparse_results:
            # Nothing to fuse; fall back to dense-only retrieval
            return dense_results[:top_k]

        # Map each result ID to its 1-based rank in each list
        dense_ranks = {self._get_result_id(r): i + 1 for i, r in enumerate(dense_results)}
        sparse_ranks = {self._get_result_id(r): i + 1 for i, r in enumerate(sparse_results)}

        # Weighted RRF: score = dense_weight / (k + dense_rank)
        #                     + sparse_weight / (k + sparse_rank),
        # with k = 60 as the conventional smoothing constant. A result missing
        # from one list gets a large penalty rank there.
        k = 60
        missing_rank = 1000
        combined_scores = {}
        for result_id in set(dense_ranks) | set(sparse_ranks):
            dense_rank = dense_ranks.get(result_id, missing_rank)
            sparse_rank = sparse_ranks.get(result_id, missing_rank)
            combined_scores[result_id] = (
                self.dense_weight / (k + dense_rank)
                + self.sparse_weight / (k + sparse_rank)
            )

        # Order IDs by fused score, best first
        ranked_ids = sorted(combined_scores, key=combined_scores.get, reverse=True)

        # Recover the original result payloads for the top IDs
        result_map = {self._get_result_id(r): r for r in (dense_results + sparse_results)}
        combined_results = []

        for result_id in ranked_ids[:top_k]:
            if result_id in result_map:
                result = result_map[result_id].copy()
                result["hybrid_score"] = combined_scores[result_id]
                combined_results.append(result)

        self.logger.debug(
            f"Hybrid search combined {len(dense_results)} dense + "
            f"{len(sparse_results)} sparse results"
        )

        return combined_results

    def _get_result_id(self, result: Dict[str, Any]) -> str:
        """Generate a stable ID for a result from its content hash"""
        content = result.get("content", "")
        return hashlib.md5(content.encode()).hexdigest()[:16]
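

# Minimal end-to-end sketch (illustrative only: the documents, scores, and
# sources below are made-up stand-ins for whatever the real retrieval backends
# return). It demonstrates deduplication plus final ranking via
# RetrievalOptimizer, then weighted RRF fusion via HybridSearchOptimizer.
async def _demo_pipeline() -> None:
    optimizer = RetrievalOptimizer(final_top_k=2)
    base_results = [
        {"content": "Paris is the capital of France.", "score": 0.92, "source": "doc1"},
        {"content": "Paris is the capital of France.", "score": 0.90, "source": "doc2"},
        {"content": "Berlin is the capital of Germany.", "score": 0.40, "source": "doc3"},
    ]
    # The exact duplicate from doc2 is removed; survivors are ranked and truncated
    optimized = await optimizer.optimize_retrieval("capital of France", base_results)
    for r in optimized:
        print(f"rank={r.rank} score={r.score:.3f} source={r.source}")

    # Fuse a dense result list with a (here, single-entry) sparse result list
    hybrid = HybridSearchOptimizer()
    fused = await hybrid.hybrid_search(
        query="capital of France",
        dense_results=base_results,
        sparse_results=[{"content": "Berlin is the capital of Germany.", "score": 3.1}],
        top_k=2,
    )
    print([round(r["hybrid_score"], 4) for r in fused])


if __name__ == "__main__":
    asyncio.run(_demo_pipeline())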