# telegram-analytics / hybrid_search.py
# Uploaded by rottg -- "Update code" (commit 85ff768, verified)
"""
Hybrid Search - Combines Vector Search, BM25, and Query Expansion
This provides much better search for chat data by:
1. Chunk-based vector search (captures context)
2. BM25 keyword search (finds exact matches)
3. Query expansion (handles variations)
"""
import sqlite3
import numpy as np
import pickle
import re
import os
from typing import List, Dict, Any, Optional
# Try importing sentence-transformers
try:
from sentence_transformers import SentenceTransformer
HAS_TRANSFORMERS = True
except ImportError:
HAS_TRANSFORMERS = False
# Try importing BM25
try:
from rank_bm25 import BM25Okapi
HAS_BM25 = True
except ImportError:
HAS_BM25 = False
class HybridSearch:
    """
    Hybrid search combining:
    - Chunk-based vector search (conversation context)
    - BM25 keyword search (exact matches)
    - Query expansion (synonyms, variations)

    All indexes and the embedding model are lazy-loaded on first use, so
    constructing the object is cheap and never touches disk.
    """

    def __init__(self,
                 messages_db: str = 'telegram.db',
                 chunk_embeddings_db: str = 'chunk_embeddings.db',
                 bm25_index_path: str = 'bm25_index.pkl',
                 single_embeddings_db: str = 'embeddings.db'):
        """
        Args:
            messages_db: SQLite DB containing the raw `messages` table.
            chunk_embeddings_db: SQLite DB with precomputed chunk embeddings.
            bm25_index_path: Pickled BM25 index built offline.
            single_embeddings_db: SQLite DB with per-message embeddings (fallback).
        """
        self.messages_db = messages_db
        self.chunk_embeddings_db = chunk_embeddings_db
        self.bm25_index_path = bm25_index_path
        self.single_embeddings_db = single_embeddings_db
        # Lazy-loaded components
        self.model = None
        self.chunk_embeddings = None    # (n_chunks, dim) float32, L2-normalized rows
        self.chunk_data = None          # list of dicts, parallel to chunk_embeddings
        self.bm25 = None
        self.bm25_message_ids = None    # doc index -> message_id, parallel to BM25 corpus
        self.single_embeddings = None   # (n_msgs, dim) float32, L2-normalized rows
        self.single_message_ids = None  # row index -> message_id

    @staticmethod
    def _normalize_rows(matrix: np.ndarray) -> np.ndarray:
        """L2-normalize each row; all-zero rows are left unchanged (norm treated as 1)."""
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        return matrix / np.where(norms == 0, 1, norms)

    def _load_model(self):
        """Load the embedding model (idempotent)."""
        if self.model is not None:
            return
        if not HAS_TRANSFORMERS:
            raise RuntimeError("sentence-transformers not installed")
        print("Loading embedding model...")
        self.model = SentenceTransformer('intfloat/multilingual-e5-large')
        print("Model loaded!")

    def _encode_query(self, query: str) -> np.ndarray:
        """Encode a query with the e5 'query: ' prefix and L2-normalize it."""
        self._load_model()
        emb = self.model.encode([f"query: {query}"], convert_to_numpy=True)[0]
        norm = np.linalg.norm(emb)
        # Guard against a zero vector, which would otherwise yield NaN similarities.
        return emb / norm if norm else emb

    def _load_chunk_embeddings(self):
        """Load chunk embeddings. Returns True if an index is available."""
        if self.chunk_embeddings is not None:
            return True
        if not os.path.exists(self.chunk_embeddings_db):
            print(f"Chunk embeddings not found: {self.chunk_embeddings_db}")
            return False
        print(f"Loading chunk embeddings from {self.chunk_embeddings_db}...")
        conn = sqlite3.connect(self.chunk_embeddings_db)
        try:
            # Check if chunk_type column exists (for backwards compatibility)
            cursor = conn.execute("PRAGMA table_info(chunk_embeddings)")
            columns = [col[1] for col in cursor.fetchall()]
            if 'chunk_type' in columns:
                rows = conn.execute("""
                    SELECT chunk_id, chunk_type, text, message_ids, anchor_message_id, embedding
                    FROM chunk_embeddings
                """).fetchall()
            else:
                # Old schema: every chunk is a sliding-window chunk.
                rows = conn.execute("""
                    SELECT chunk_id, 'window' as chunk_type, text, message_ids, anchor_message_id, embedding
                    FROM chunk_embeddings
                """).fetchall()
        finally:
            conn.close()
        if not rows:
            return False
        import json
        self.chunk_data = []
        emb_list = []
        for chunk_id, chunk_type, text, msg_ids_json, anchor_id, emb_blob in rows:
            emb_list.append(np.frombuffer(emb_blob, dtype=np.float32))
            self.chunk_data.append({
                'chunk_id': chunk_id,
                'chunk_type': chunk_type,
                'text': text,
                'message_ids': json.loads(msg_ids_json),
                'anchor_message_id': anchor_id
            })
        # Normalize once so search is a plain dot product.
        self.chunk_embeddings = self._normalize_rows(np.vstack(emb_list))
        print(f"Loaded {len(self.chunk_data)} chunk embeddings")
        return True

    def _load_single_embeddings(self):
        """Load single-message embeddings (fallback). Returns True if available."""
        if self.single_embeddings is not None:
            return True
        if not os.path.exists(self.single_embeddings_db):
            return False
        print(f"Loading single embeddings from {self.single_embeddings_db}...")
        conn = sqlite3.connect(self.single_embeddings_db)
        try:
            rows = conn.execute("""
                SELECT message_id, embedding FROM embeddings
            """).fetchall()
        finally:
            conn.close()
        if not rows:
            return False
        self.single_message_ids = []
        emb_list = []
        for msg_id, emb_blob in rows:
            emb_list.append(np.frombuffer(emb_blob, dtype=np.float32))
            self.single_message_ids.append(msg_id)
        self.single_embeddings = self._normalize_rows(np.vstack(emb_list))
        print(f"Loaded {len(self.single_message_ids)} single embeddings")
        return True

    def _load_bm25(self):
        """Load BM25 index. Returns True if available."""
        if self.bm25 is not None:
            return True
        if not os.path.exists(self.bm25_index_path):
            print(f"BM25 index not found: {self.bm25_index_path}")
            return False
        print(f"Loading BM25 index from {self.bm25_index_path}...")
        # NOTE(security): pickle.load executes arbitrary code from the file.
        # Only load indexes you built yourself; never a downloaded one.
        with open(self.bm25_index_path, 'rb') as f:
            data = pickle.load(f)
        self.bm25 = data['bm25']
        self.bm25_message_ids = data['message_ids']
        print(f"Loaded BM25 index with {len(self.bm25_message_ids)} documents")
        return True

    def expand_query(self, query: str) -> List[str]:
        """
        Expand query with variations.
        Returns list of query variations to search (original first, max 5).
        """
        queries = [query]
        # Hebrew question word expansions
        expansions = {
            'איפה': ['איפה', 'היכן', 'מיקום', 'כתובת', 'עיר'],
            'מתי': ['מתי', 'באיזה תאריך', 'מועד', 'זמן'],
            'מי': ['מי', 'מיהו', 'מיהי', 'שם'],
            'כמה': ['כמה', 'מספר', 'כמות'],
            'למה': ['למה', 'מדוע', 'סיבה'],
            'גר': ['גר', 'גרה', 'מתגורר', 'מתגוררת', 'גרים'],
            'עובד': ['עובד', 'עובדת', 'עובדים', 'מועסק', 'עבודה'],
        }
        # Add expanded variations. Match WHOLE words only: plain substring
        # replacement corrupts the query (e.g. 'גר' is a prefix of 'גרה',
        # so 'גרה'.replace('גר', 'גרה') would produce 'גרהה').
        for word, synonyms in expansions.items():
            pattern = re.compile(r'(?<!\w)' + re.escape(word) + r'(?!\w)')
            if not pattern.search(query):
                continue
            for syn in synonyms:
                if syn == word:
                    continue
                expanded = pattern.sub(syn, query)
                if expanded not in queries:
                    queries.append(expanded)
        return queries[:5]  # Limit to 5 variations

    def search_chunks(self, query: str, limit: int = 20) -> List[Dict]:
        """Search using chunk embeddings (context-aware). Returns [] if no index."""
        if not self._load_chunk_embeddings():
            return []
        query_norm = self._encode_query(query)
        # Cosine similarity == dot product (both sides are normalized).
        similarities = np.dot(self.chunk_embeddings, query_norm)
        top_indices = np.argsort(similarities)[::-1][:limit]
        results = []
        for idx in top_indices:
            chunk = self.chunk_data[idx]
            results.append({
                'type': 'chunk',
                'chunk_type': chunk.get('chunk_type', 'window'),  # 'thread' or 'window'
                'chunk_id': chunk['chunk_id'],
                'text': chunk['text'],
                'message_ids': chunk['message_ids'],
                'anchor_message_id': chunk['anchor_message_id'],
                'score': float(similarities[idx])
            })
        return results

    def search_bm25(self, query: str, limit: int = 20) -> List[Dict]:
        """Search using BM25 (keyword-based). Returns [] if no index."""
        if not self._load_bm25():
            return []
        # Tokenize the same way the index was presumably built -- TODO confirm
        # the offline builder uses the same lowercase \w+ tokenization.
        query_tokens = re.findall(r'\w+', query.lower())
        scores = self.bm25.get_scores(query_tokens)
        top_indices = np.argsort(scores)[::-1][:limit]
        results = []
        for idx in top_indices:
            score = float(scores[idx])
            if score > 0:  # drop non-matching documents
                results.append({
                    'type': 'bm25',
                    'message_id': self.bm25_message_ids[idx],
                    'score': score
                })
        return results

    def search_single(self, query: str, limit: int = 20) -> List[Dict]:
        """Search using single-message embeddings (fallback). Returns [] if no index."""
        if not self._load_single_embeddings():
            return []
        query_norm = self._encode_query(query)
        similarities = np.dot(self.single_embeddings, query_norm)
        top_indices = np.argsort(similarities)[::-1][:limit]
        return [{
            'type': 'single',
            'message_id': self.single_message_ids[idx],
            'score': float(similarities[idx])
        } for idx in top_indices]

    def hybrid_search(self, query: str, limit: int = 20,
                      vector_weight: float = 0.6,
                      bm25_weight: float = 0.4,
                      use_expansion: bool = True) -> List[Dict]:
        """
        Hybrid search combining vector and BM25.

        Args:
            query: Search query
            limit: Max results
            vector_weight: Weight for vector search (0-1)
            bm25_weight: Weight for BM25 search (0-1)
            use_expansion: Whether to expand query

        Returns:
            Combined search results sorted by weighted score, one per message.
        """
        all_message_scores = {}
        queries = self.expand_query(query) if use_expansion else [query]
        for q in queries:
            # Chunk/Vector search: a chunk's score is credited to every
            # message it contains; keep the max across query variants.
            for r in self.search_chunks(q, limit=limit * 2):
                for msg_id in r['message_ids']:
                    entry = all_message_scores.setdefault(
                        msg_id, {'vector': 0, 'bm25': 0, 'chunk_text': None})
                    entry['vector'] = max(entry['vector'], r['score'] * vector_weight)
                    if entry['chunk_text'] is None:
                        entry['chunk_text'] = r['text']
            # BM25 search
            for r in self.search_bm25(q, limit=limit * 2):
                entry = all_message_scores.setdefault(
                    r['message_id'], {'vector': 0, 'bm25': 0, 'chunk_text': None})
                # /10 roughly rescales unbounded BM25 scores toward the
                # [0, 1] cosine range -- heuristic, not a true normalization.
                entry['bm25'] = max(entry['bm25'], r['score'] * bm25_weight / 10)
        # Combine scores
        combined = [{
            'message_id': msg_id,
            'score': scores['vector'] + scores['bm25'],
            'vector_score': scores['vector'],
            'bm25_score': scores['bm25'],
            'chunk_text': scores['chunk_text']
        } for msg_id, scores in all_message_scores.items()]
        combined.sort(key=lambda x: x['score'], reverse=True)
        return combined[:limit]

    def search_with_context(self, query: str, limit: int = 20,
                            context_window: int = 3) -> List[Dict]:
        """
        Search and return results with surrounding context.

        Args:
            query: Search query
            limit: Max results
            context_window: Messages before/after to include

        Returns:
            Results with the full message row plus neighboring messages.
        """
        results = self.hybrid_search(query, limit=limit)
        if not results:
            return []
        conn = sqlite3.connect(self.messages_db)
        conn.row_factory = sqlite3.Row
        try:
            enriched = []
            for r in results:
                msg_id = r['message_id']
                msg = conn.execute(
                    "SELECT * FROM messages WHERE id = ?", (msg_id,)
                ).fetchone()
                if not msg:
                    continue  # stale id from an old index
                # Neighbors by timestamp (strict </> so the hit itself is excluded).
                context_before = conn.execute("""
                    SELECT id, date, from_name, text_plain FROM messages
                    WHERE date_unixtime < (SELECT date_unixtime FROM messages WHERE id = ?)
                    ORDER BY date_unixtime DESC LIMIT ?
                """, (msg_id, context_window)).fetchall()
                context_after = conn.execute("""
                    SELECT id, date, from_name, text_plain FROM messages
                    WHERE date_unixtime > (SELECT date_unixtime FROM messages WHERE id = ?)
                    ORDER BY date_unixtime ASC LIMIT ?
                """, (msg_id, context_window)).fetchall()
                enriched.append({
                    'message_id': msg_id,
                    'score': r['score'],
                    'message': {
                        'id': msg['id'],
                        'date': msg['date'],
                        'from_name': msg['from_name'],
                        'text': msg['text_plain']
                    },
                    # reversed() restores chronological order for the "before" rows
                    'context_before': [dict(m) for m in reversed(context_before)],
                    'context_after': [dict(m) for m in context_after],
                    'chunk_text': r.get('chunk_text')
                })
        finally:
            conn.close()
        return enriched

    def stats(self) -> Dict[str, Any]:
        """Get search index statistics (availability flags plus counts)."""
        stats = {
            'chunks_available': os.path.exists(self.chunk_embeddings_db),
            'bm25_available': os.path.exists(self.bm25_index_path),
            'single_embeddings_available': os.path.exists(self.single_embeddings_db),
        }
        if stats['chunks_available']:
            conn = sqlite3.connect(self.chunk_embeddings_db)
            try:
                stats['chunk_count'] = conn.execute(
                    "SELECT COUNT(*) FROM chunk_embeddings"
                ).fetchone()[0]
            finally:
                conn.close()
        if stats['single_embeddings_available']:
            conn = sqlite3.connect(self.single_embeddings_db)
            try:
                stats['single_embedding_count'] = conn.execute(
                    "SELECT COUNT(*) FROM embeddings"
                ).fetchone()[0]
            finally:
                conn.close()
        return stats
# Module-level cache for the shared searcher.
_hybrid_search = None


def get_hybrid_search() -> HybridSearch:
    """Return the process-wide HybridSearch instance, creating it on first call."""
    global _hybrid_search
    instance = _hybrid_search
    if instance is None:
        instance = HybridSearch()
        _hybrid_search = instance
    return instance
# CLI for testing
if __name__ == '__main__':
    import sys

    if len(sys.argv) < 2:
        print("Usage: python hybrid_search.py 'search query'")
        print("\nStats:")
        hs = get_hybrid_search()
        print(hs.stats())
        sys.exit(0)

    query = ' '.join(sys.argv[1:])
    hs = get_hybrid_search()
    print(f"\n=== Searching: {query} ===\n")

    # Show expanded queries
    expanded = hs.expand_query(query)
    print(f"Expanded queries: {expanded}\n")

    # Search
    results = hs.search_with_context(query, limit=5)
    for i, r in enumerate(results, 1):
        print(f"--- Result {i} (score: {r['score']:.3f}) ---")
        print(f"From: {r['message']['from_name']}")
        print(f"Date: {r['message']['date']}")
        # text/text_plain can be NULL in the DB (e.g. media-only messages);
        # guard with `or ''` -- slicing None raises TypeError.
        print(f"Text: {(r['message']['text'] or '')[:200]}...")
        if r['context_before']:
            print(f"\nContext before:")
            for ctx in r['context_before']:
                print(f"  [{ctx['from_name']}] {(ctx['text_plain'] or '')[:100]}...")
        if r['context_after']:
            print(f"\nContext after:")
            for ctx in r['context_after']:
                print(f"  [{ctx['from_name']}] {(ctx['text_plain'] or '')[:100]}...")
        print()