"""
Hybrid Search - Combines Vector Search, BM25, and Query Expansion

This provides much better search for chat data by:
1. Chunk-based vector search (captures context)
2. BM25 keyword search (finds exact matches)
3. Query expansion (handles variations)
"""

import json
import os
import pickle
import re
import sqlite3
from contextlib import closing
from typing import List, Dict, Any, Optional

import numpy as np

# Try importing sentence-transformers
try:
    from sentence_transformers import SentenceTransformer
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False

# Try importing BM25
try:
    from rank_bm25 import BM25Okapi
    HAS_BM25 = True
except ImportError:
    HAS_BM25 = False


# Upper bound on the number of query variations returned by expand_query
# (the original query counts as one of them).
_MAX_QUERY_VARIATIONS = 5

# Raw BM25 scores are divided by this before weighting so that they land on
# a scale closer to the cosine similarities they are added to.
_BM25_SCORE_SCALE = 10

# Hebrew question words / verbs and the variations they expand to.
_QUERY_EXPANSIONS = {
    'איפה': ['איפה', 'היכן', 'מיקום', 'כתובת', 'עיר'],
    'מתי': ['מתי', 'באיזה תאריך', 'מועד', 'זמן'],
    'מי': ['מי', 'מיהו', 'מיהי', 'שם'],
    'כמה': ['כמה', 'מספר', 'כמות'],
    'למה': ['למה', 'מדוע', 'סיבה'],
    'גר': ['גר', 'גרה', 'מתגורר', 'מתגוררת', 'גרים'],
    'עובד': ['עובד', 'עובדת', 'עובדים', 'מועסק', 'עבודה'],
}


class HybridSearch:
    """
    Hybrid search combining:
    - Chunk-based vector search (conversation context)
    - BM25 keyword search (exact matches)
    - Query expansion (synonyms, variations)

    All heavy components (embedding model, embedding matrices, BM25 index)
    are loaded lazily on first use and cached on the instance.
    """

    def __init__(self,
                 messages_db: str = 'telegram.db',
                 chunk_embeddings_db: str = 'chunk_embeddings.db',
                 bm25_index_path: str = 'bm25_index.pkl',
                 single_embeddings_db: str = 'embeddings.db'):
        """
        Store the paths of the data sources; nothing is opened here.

        Args:
            messages_db: SQLite file holding the ``messages`` table.
            chunk_embeddings_db: SQLite file holding ``chunk_embeddings``.
            bm25_index_path: Pickle file holding the BM25 index.
            single_embeddings_db: SQLite file holding ``embeddings``.
        """
        self.messages_db = messages_db
        self.chunk_embeddings_db = chunk_embeddings_db
        self.bm25_index_path = bm25_index_path
        self.single_embeddings_db = single_embeddings_db

        # Lazy-loaded components
        self.model = None
        self.chunk_embeddings = None
        self.chunk_data = None
        self.bm25 = None
        self.bm25_message_ids = None
        self.single_embeddings = None
        self.single_message_ids = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_rows(matrix):
        """Return ``matrix`` with every row scaled to unit L2 norm.

        All-zero rows are left as zeros instead of producing NaN.
        """
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        return matrix / np.where(norms == 0, 1, norms)

    def _encode_query(self, query: str):
        """Embed ``query`` with the model and return it L2-normalised.

        Raises:
            RuntimeError: if sentence-transformers is not installed.
        """
        self._load_model()
        # Encode query with e5 prefix
        vector = self.model.encode(
            [f"query: {query}"], convert_to_numpy=True
        )[0]
        norm = np.linalg.norm(vector)
        # Guard against a zero vector: dividing would turn every score to NaN.
        return vector / norm if norm > 0 else vector

    # ------------------------------------------------------------------
    # Lazy loaders
    # ------------------------------------------------------------------

    def _load_model(self):
        """Load the embedding model (no-op when already loaded).

        Raises:
            RuntimeError: if sentence-transformers is not installed.
        """
        if self.model is not None:
            return

        if not HAS_TRANSFORMERS:
            raise RuntimeError("sentence-transformers not installed")

        print("Loading embedding model...")
        self.model = SentenceTransformer('intfloat/multilingual-e5-large')
        print("Model loaded!")

    def _load_chunk_embeddings(self):
        """Load chunk embeddings.

        Returns:
            True when embeddings are available (already loaded or loaded
            now), False when the DB file is missing or the table is empty.
        """
        if self.chunk_embeddings is not None:
            return True

        if not os.path.exists(self.chunk_embeddings_db):
            print(f"Chunk embeddings not found: {self.chunk_embeddings_db}")
            return False

        print(f"Loading chunk embeddings from {self.chunk_embeddings_db}...")

        # closing() guarantees the connection is released even if a query
        # raises (sqlite3's own context manager only handles transactions).
        with closing(sqlite3.connect(self.chunk_embeddings_db)) as conn:
            # Check if chunk_type column exists (for backwards compatibility)
            cursor = conn.execute("PRAGMA table_info(chunk_embeddings)")
            columns = [col[1] for col in cursor.fetchall()]

            if 'chunk_type' in columns:
                rows = conn.execute("""
                    SELECT chunk_id, chunk_type, text, message_ids,
                           anchor_message_id, embedding
                    FROM chunk_embeddings
                """).fetchall()
            else:
                # Older databases: treat every chunk as a 'window' chunk.
                rows = conn.execute("""
                    SELECT chunk_id, 'window' as chunk_type, text, message_ids,
                           anchor_message_id, embedding
                    FROM chunk_embeddings
                """).fetchall()

        if not rows:
            return False

        chunk_data = []
        emb_list = []
        for chunk_id, chunk_type, text, msg_ids_json, anchor_id, emb_blob in rows:
            emb_list.append(np.frombuffer(emb_blob, dtype=np.float32))
            chunk_data.append({
                'chunk_id': chunk_id,
                'chunk_type': chunk_type,
                'text': text,
                'message_ids': json.loads(msg_ids_json),
                'anchor_message_id': anchor_id
            })

        # Publish both attributes together, only after everything parsed, so
        # a failure above never leaves the instance half-initialised.
        self.chunk_embeddings = self._normalize_rows(np.vstack(emb_list))
        self.chunk_data = chunk_data

        print(f"Loaded {len(self.chunk_data)} chunk embeddings")
        return True

    def _load_single_embeddings(self):
        """Load single-message embeddings (fallback).

        Returns:
            True when embeddings are available, False when the DB file is
            missing or the table is empty.
        """
        if self.single_embeddings is not None:
            return True

        if not os.path.exists(self.single_embeddings_db):
            return False

        print(f"Loading single embeddings from {self.single_embeddings_db}...")

        with closing(sqlite3.connect(self.single_embeddings_db)) as conn:
            rows = conn.execute("""
                SELECT message_id, embedding FROM embeddings
            """).fetchall()

        if not rows:
            return False

        message_ids = []
        emb_list = []
        for msg_id, emb_blob in rows:
            emb_list.append(np.frombuffer(emb_blob, dtype=np.float32))
            message_ids.append(msg_id)

        self.single_embeddings = self._normalize_rows(np.vstack(emb_list))
        self.single_message_ids = message_ids

        print(f"Loaded {len(self.single_message_ids)} single embeddings")
        return True

    def _load_bm25(self):
        """Load BM25 index.

        Returns:
            True when the index is available, False when the file is missing.
        """
        if self.bm25 is not None:
            return True

        if not os.path.exists(self.bm25_index_path):
            print(f"BM25 index not found: {self.bm25_index_path}")
            return False

        print(f"Loading BM25 index from {self.bm25_index_path}...")

        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. Only point bm25_index_path at an index you built yourself.
        with open(self.bm25_index_path, 'rb') as f:
            data = pickle.load(f)

        self.bm25 = data['bm25']
        self.bm25_message_ids = data['message_ids']

        print(f"Loaded BM25 index with {len(self.bm25_message_ids)} documents")
        return True

    # ------------------------------------------------------------------
    # Query expansion
    # ------------------------------------------------------------------

    def expand_query(self, query: str) -> List[str]:
        """
        Expand query with variations.

        A known word is only expanded when it appears in the query as a
        whole word (not as a fragment of a longer word), and only those
        whole-word occurrences are substituted.

        Returns list of query variations to search: the original query
        first, followed by the variations, capped at 5 entries in total.
        """
        queries = [query]

        for word, synonyms in _QUERY_EXPANSIONS.items():
            # \w is Unicode-aware for str patterns, so the lookarounds stop
            # e.g. 'מי' from matching inside 'מים'.
            pattern = re.compile(rf'(?<!\w){re.escape(word)}(?!\w)')
            if not pattern.search(query):
                continue

            for syn in synonyms:
                if syn == word:
                    continue
                # A callable replacement inserts the synonym literally.
                expanded = pattern.sub(lambda _match, s=syn: s, query)
                if expanded not in queries:
                    queries.append(expanded)

        return queries[:_MAX_QUERY_VARIATIONS]  # Limit to 5 variations

    # ------------------------------------------------------------------
    # Individual search back-ends
    # ------------------------------------------------------------------

    def search_chunks(self, query: str, limit: int = 20) -> List[Dict]:
        """Search using chunk embeddings (context-aware).

        Returns:
            Up to ``limit`` chunk dicts ordered by descending similarity
            (dot product of the normalised vectors); empty when chunk
            embeddings are unavailable.
        """
        if not self._load_chunk_embeddings():
            return []

        query_norm = self._encode_query(query)

        # Compute similarities
        similarities = np.dot(self.chunk_embeddings, query_norm)

        # Get top results
        top_indices = np.argsort(similarities)[::-1][:limit]

        results = []
        for idx in top_indices:
            chunk = self.chunk_data[idx]
            results.append({
                'type': 'chunk',
                'chunk_type': chunk.get('chunk_type', 'window'),  # 'thread' or 'window'
                'chunk_id': chunk['chunk_id'],
                'text': chunk['text'],
                'message_ids': chunk['message_ids'],
                'anchor_message_id': chunk['anchor_message_id'],
                'score': float(similarities[idx])
            })

        return results

    def search_bm25(self, query: str, limit: int = 20) -> List[Dict]:
        """Search using BM25 (keyword-based).

        Returns:
            Up to ``limit`` hits with a strictly positive BM25 score, best
            first; empty when the index is unavailable.
        """
        if not self._load_bm25():
            return []

        # Tokenize query
        query_tokens = re.findall(r'\w+', query.lower())

        # Get BM25 scores
        scores = self.bm25.get_scores(query_tokens)

        # Get top results
        top_indices = np.argsort(scores)[::-1][:limit]

        results = []
        for idx in top_indices:
            score = float(scores[idx])
            if score > 0:
                results.append({
                    'type': 'bm25',
                    'message_id': self.bm25_message_ids[idx],
                    'score': score
                })

        return results

    def search_single(self, query: str, limit: int = 20) -> List[Dict]:
        """Search using single-message embeddings (fallback).

        Returns:
            Up to ``limit`` hits ordered by descending similarity; empty
            when single-message embeddings are unavailable.
        """
        if not self._load_single_embeddings():
            return []

        query_norm = self._encode_query(query)

        similarities = np.dot(self.single_embeddings, query_norm)
        top_indices = np.argsort(similarities)[::-1][:limit]

        return [
            {
                'type': 'single',
                'message_id': self.single_message_ids[idx],
                'score': float(similarities[idx])
            }
            for idx in top_indices
        ]

    # ------------------------------------------------------------------
    # Combined search
    # ------------------------------------------------------------------

    def hybrid_search(self, query: str, limit: int = 20,
                      vector_weight: float = 0.6,
                      bm25_weight: float = 0.4,
                      use_expansion: bool = True) -> List[Dict]:
        """
        Hybrid search combining vector and BM25.

        Args:
            query: Search query
            limit: Max results
            vector_weight: Weight for vector search (0-1)
            bm25_weight: Weight for BM25 search (0-1)
            use_expansion: Whether to expand query

        Returns:
            Combined search results, one dict per message, sorted by
            descending combined score. ``chunk_text`` is the text of the
            best-scoring chunk that contained the message (None if the
            message was only found by BM25).
        """
        all_message_scores: Dict[Any, Dict[str, Any]] = {}

        # Get expanded queries
        queries = self.expand_query(query) if use_expansion else [query]

        # Search with each query variation
        for q in queries:
            # Chunk/Vector search
            for r in self.search_chunks(q, limit=limit * 2):
                weighted = r['score'] * vector_weight
                for msg_id in r['message_ids']:
                    entry = all_message_scores.setdefault(
                        msg_id, {'vector': 0, 'bm25': 0, 'chunk_text': None}
                    )
                    # Keep the text of whichever chunk yields the score we
                    # report, so chunk_text and vector_score stay in sync.
                    if entry['chunk_text'] is None or weighted > entry['vector']:
                        entry['chunk_text'] = r['text']
                    # Use max score across message appearances
                    entry['vector'] = max(entry['vector'], weighted)

            # BM25 search
            for r in self.search_bm25(q, limit=limit * 2):
                entry = all_message_scores.setdefault(
                    r['message_id'],
                    {'vector': 0, 'bm25': 0, 'chunk_text': None}
                )
                entry['bm25'] = max(
                    entry['bm25'],
                    r['score'] * bm25_weight / _BM25_SCORE_SCALE  # Normalize BM25 scores
                )

        # Combine scores
        combined = [
            {
                'message_id': msg_id,
                'score': scores['vector'] + scores['bm25'],
                'vector_score': scores['vector'],
                'bm25_score': scores['bm25'],
                'chunk_text': scores['chunk_text']
            }
            for msg_id, scores in all_message_scores.items()
        ]

        # Sort by combined score
        combined.sort(key=lambda x: x['score'], reverse=True)

        return combined[:limit]

    def search_with_context(self, query: str, limit: int = 20,
                            context_window: int = 3) -> List[Dict]:
        """
        Search and return results with surrounding context.

        Args:
            query: Search query
            limit: Max results
            context_window: Messages before/after to include

        Returns:
            Results with full context. Hits whose message id is not present
            in the ``messages`` table are dropped. Context is selected by
            ``date_unixtime`` strictly before/after the hit and returned in
            chronological order.
        """
        # Get hybrid search results
        results = self.hybrid_search(query, limit=limit)

        if not results:
            return []

        enriched = []

        # Get full context from DB; closing() releases the connection even
        # when one of the queries raises.
        with closing(sqlite3.connect(self.messages_db)) as conn:
            conn.row_factory = sqlite3.Row

            for r in results:
                msg_id = r['message_id']

                # Get the message
                msg = conn.execute(
                    "SELECT * FROM messages WHERE id = ?", (msg_id,)
                ).fetchone()

                if not msg:
                    continue

                # Get surrounding messages
                context_before = conn.execute("""
                    SELECT id, date, from_name, text_plain
                    FROM messages
                    WHERE date_unixtime < (SELECT date_unixtime FROM messages WHERE id = ?)
                    ORDER BY date_unixtime DESC
                    LIMIT ?
                """, (msg_id, context_window)).fetchall()

                context_after = conn.execute("""
                    SELECT id, date, from_name, text_plain
                    FROM messages
                    WHERE date_unixtime > (SELECT date_unixtime FROM messages WHERE id = ?)
                    ORDER BY date_unixtime ASC
                    LIMIT ?
                """, (msg_id, context_window)).fetchall()

                enriched.append({
                    'message_id': msg_id,
                    'score': r['score'],
                    'message': {
                        'id': msg['id'],
                        'date': msg['date'],
                        'from_name': msg['from_name'],
                        'text': msg['text_plain']
                    },
                    # Fetched newest-first; flip back to chronological order.
                    'context_before': [dict(m) for m in reversed(context_before)],
                    'context_after': [dict(m) for m in context_after],
                    'chunk_text': r.get('chunk_text')
                })

        return enriched

    def stats(self) -> Dict[str, Any]:
        """Get search index statistics.

        Returns:
            Availability flags for each index file, plus row counts for the
            SQLite-backed indexes whose files exist.
        """
        stats = {
            'chunks_available': os.path.exists(self.chunk_embeddings_db),
            'bm25_available': os.path.exists(self.bm25_index_path),
            'single_embeddings_available': os.path.exists(self.single_embeddings_db),
        }

        if stats['chunks_available']:
            with closing(sqlite3.connect(self.chunk_embeddings_db)) as conn:
                stats['chunk_count'] = conn.execute(
                    "SELECT COUNT(*) FROM chunk_embeddings"
                ).fetchone()[0]

        if stats['single_embeddings_available']:
            with closing(sqlite3.connect(self.single_embeddings_db)) as conn:
                stats['single_embedding_count'] = conn.execute(
                    "SELECT COUNT(*) FROM embeddings"
                ).fetchone()[0]

        return stats


# Singleton instance
_hybrid_search = None


def get_hybrid_search() -> HybridSearch:
    """Get or create hybrid search instance."""
    global _hybrid_search
    if _hybrid_search is None:
        _hybrid_search = HybridSearch()
    return _hybrid_search


def _preview(text: Optional[str], max_chars: int) -> str:
    """Return the first ``max_chars`` characters of ``text`` followed by '...'.

    ``None`` (a message without text) is treated as an empty string.
    """
    return f"{(text or '')[:max_chars]}..."


# CLI for testing
if __name__ == '__main__':
    import sys

    if len(sys.argv) < 2:
        print("Usage: python hybrid_search.py 'search query'")
        print("\nStats:")
        hs = get_hybrid_search()
        print(hs.stats())
        sys.exit(0)

    query = ' '.join(sys.argv[1:])

    hs = get_hybrid_search()

    print(f"\n=== Searching: {query} ===\n")

    # Show expanded queries
    expanded = hs.expand_query(query)
    print(f"Expanded queries: {expanded}\n")

    # Search
    results = hs.search_with_context(query, limit=5)

    for i, r in enumerate(results, 1):
        print(f"--- Result {i} (score: {r['score']:.3f}) ---")
        print(f"From: {r['message']['from_name']}")
        print(f"Date: {r['message']['date']}")
        print(f"Text: {_preview(r['message']['text'], 200)}")

        if r['context_before']:
            print("\nContext before:")
            for ctx in r['context_before']:
                print(f"  [{ctx['from_name']}] {_preview(ctx['text_plain'], 100)}")

        if r['context_after']:
            print("\nContext after:")
            for ctx in r['context_after']:
                print(f"  [{ctx['from_name']}] {_preview(ctx['text_plain'], 100)}")

        print()