junaid17 commited on
Commit
717901c
·
verified ·
1 Parent(s): e4b6894

Update text_engine.py

Browse files
Files changed (1) hide show
  1. text_engine.py +224 -193
text_engine.py CHANGED
@@ -1,194 +1,225 @@
1
- # text_engine.py
2
- import os
3
- import pickle
4
- import logging
5
- from typing import List, Optional
6
- import numpy as np
7
- from sentence_transformers import SentenceTransformer
8
- import faiss
9
- from rank_bm25 import BM25Okapi
10
-
11
- logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger(__name__)
13
-
14
-
15
- class Text_Search_Engine:
16
- def __init__(
17
- self,
18
- base_folder: str = "vector_store",
19
- model_name: str = "sentence-transformers/LaBSE",
20
- index_type: str = "flat",
21
- ):
22
- self.base_folder = base_folder
23
- self.embeddings_folder = os.path.join(base_folder, "embeddings")
24
- self.docs_folder = os.path.join(base_folder, "documents")
25
- os.makedirs(self.embeddings_folder, exist_ok=True)
26
- os.makedirs(self.docs_folder, exist_ok=True)
27
-
28
- self.model = SentenceTransformer(model_name)
29
- self.index: Optional[faiss.Index] = None
30
- self.rows: List[dict] = []
31
- self.texts: List[str] = []
32
- self.bm25: Optional[BM25Okapi] = None
33
- self.index_type = index_type
34
-
35
- # -------------------------
36
- # Index creation utilities
37
- # -------------------------
38
- def _create_index(self, dimension: int, embeddings: np.ndarray):
39
- if self.index_type == "flat":
40
- self.index = faiss.IndexFlatL2(dimension)
41
- elif self.index_type == "ivf":
42
- nlist = max(1, min(256, len(embeddings) // 10))
43
- quantizer = faiss.IndexFlatL2(dimension)
44
- self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)
45
- self.index.train(np.array(embeddings).astype("float32"))
46
- elif self.index_type == "hnsw":
47
- self.index = faiss.IndexHNSWFlat(dimension, 32)
48
- else:
49
- raise ValueError(f"Unsupported index type: {self.index_type}")
50
-
51
- def _persist(self):
52
- try:
53
- if self.index is not None:
54
- faiss.write_index(self.index, os.path.join(self.embeddings_folder, "multilingual.index"))
55
- with open(os.path.join(self.docs_folder, "rows.pkl"), "wb") as f:
56
- pickle.dump(self.rows, f)
57
- logger.info("Persisted index and rows to disk.")
58
- except Exception as e:
59
- logger.exception("Failed to persist index/rows: %s", e)
60
-
61
- # -------------------------
62
- # Core operations
63
- # -------------------------
64
- def encode_store(self, rows: List[dict], texts: List[str]):
65
- try:
66
- embeddings = self.model.encode(texts, convert_to_numpy=True)
67
- dimension = embeddings.shape[1]
68
- self._create_index(dimension, embeddings)
69
- self.index.add(np.array(embeddings).astype("float32"))
70
-
71
- self.rows = rows
72
- self.texts = texts
73
- tokenized_corpus = [t.lower().split() for t in texts]
74
- self.bm25 = BM25Okapi(tokenized_corpus)
75
-
76
- self._persist()
77
- logger.info("Index built with %d rows (index_type=%s).", len(rows), self.index_type)
78
- except Exception as e:
79
- logger.exception("Error in encode_store: %s", e)
80
- raise
81
-
82
- def load(self):
83
- try:
84
- index_path = os.path.join(self.embeddings_folder, "multilingual.index")
85
- rows_path = os.path.join(self.docs_folder, "rows.pkl")
86
- if os.path.exists(index_path) and os.path.exists(rows_path):
87
- self.index = faiss.read_index(index_path)
88
- with open(rows_path, "rb") as f:
89
- self.rows = pickle.load(f)
90
- self.texts = [r["_search_text"] for r in self.rows]
91
- tokenized_corpus = [t.lower().split() for t in self.texts]
92
- self.bm25 = BM25Okapi(tokenized_corpus)
93
- logger.info("Loaded index and %d rows from disk.", len(self.rows))
94
- else:
95
- logger.info("No persisted index/rows found.")
96
- except Exception as e:
97
- logger.exception("Error in load: %s", e)
98
- raise
99
-
100
- def add_rows(self, new_rows: List[dict], new_texts: List[str]):
101
- try:
102
- if not new_rows:
103
- return
104
-
105
- new_embeddings = self.model.encode(new_texts, convert_to_numpy=True).astype("float32")
106
- if self.index is None:
107
- self._create_index(new_embeddings.shape[1], new_embeddings)
108
- self.index.add(new_embeddings)
109
- else:
110
- if isinstance(self.index, faiss.IndexIVFFlat) and not self.index.is_trained:
111
- combined = np.vstack([self.model.encode(self.texts, convert_to_numpy=True).astype("float32"), new_embeddings]) if self.texts else new_embeddings
112
- self.index.train(combined)
113
- self.index.add(new_embeddings)
114
-
115
- self.rows.extend(new_rows)
116
- self.texts.extend(new_texts)
117
- tokenized_corpus = [t.lower().split() for t in self.texts]
118
- self.bm25 = BM25Okapi(tokenized_corpus)
119
-
120
- self._persist()
121
- logger.info("Added %d new rows. Total rows: %d", len(new_rows), len(self.rows))
122
- except Exception as e:
123
- logger.exception("Error in add_rows: %s", e)
124
- raise
125
-
126
- # -------------------------
127
- # Search methods
128
- # -------------------------
129
- def search(self, query: str, top_k: int = 3):
130
- try:
131
- if self.index is None:
132
- return []
133
- query_emb = self.model.encode([query], convert_to_numpy=True).astype("float32")
134
- k = min(top_k, len(self.rows))
135
- distances, indices = self.index.search(query_emb, k=k)
136
- results = [
137
- {**self.rows[i], "distance": float(distances[0][j])}
138
- for j, i in enumerate(indices[0])
139
- ]
140
- return sorted(results, key=lambda x: x["distance"])
141
- except Exception as e:
142
- logger.exception("Error in search: %s", e)
143
- return []
144
-
145
- def hybrid_search(self, query: str, top_k: int = 3, alpha: float = 0.5):
146
- try:
147
- if self.index is None or self.bm25 is None:
148
- return []
149
-
150
- query_emb = self.model.encode([query], convert_to_numpy=True).astype("float32")
151
- distances, indices = self.index.search(query_emb, k=len(self.texts))
152
- semantic_scores = {i: 1 / (1 + distances[0][j]) for j, i in enumerate(indices[0])}
153
-
154
- tokenized_query = query.lower().split()
155
- bm25_scores = self.bm25.get_scores(tokenized_query)
156
- lexical_scores = {i: bm25_scores[i] for i in range(len(self.texts))}
157
-
158
- combined = []
159
- for i, row in enumerate(self.rows):
160
- sem = semantic_scores.get(i, 0.0)
161
- lex = lexical_scores.get(i, 0.0)
162
- score = alpha * sem + (1 - alpha) * lex
163
- combined.append({**row, "score": float(score)})
164
-
165
- combined = sorted(combined, key=lambda x: x["score"], reverse=True)
166
- return combined[:top_k]
167
- except Exception as e:
168
- logger.exception("Error in hybrid_search: %s", e)
169
- return []
170
-
171
- # -------------------------
172
- # Utilities
173
- # -------------------------
174
- def clear_vdb(self):
175
- try:
176
- if self.index is not None:
177
- try:
178
- self.index.reset()
179
- except Exception:
180
- self.index = None
181
- self.rows = []
182
- self.texts = []
183
- self.bm25 = None
184
-
185
- index_path = os.path.join(self.embeddings_folder, "multilingual.index")
186
- docs_path = os.path.join(self.docs_folder, "rows.pkl")
187
- if os.path.exists(index_path):
188
- os.remove(index_path)
189
- if os.path.exists(docs_path):
190
- os.remove(docs_path)
191
- logger.info("Cleared vector DB and persisted files.")
192
- except Exception as e:
193
- logger.exception("Error in clear_vdb: %s", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  raise
 
1
+ # text_engine.py
2
+ import os
3
+ import pickle
4
+ import logging
5
+ from typing import List, Optional
6
+ import numpy as np
7
+ from sentence_transformers import SentenceTransformer
8
+ import faiss
9
+ from rank_bm25 import BM25Okapi
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class Text_Search_Engine:
    """Semantic (FAISS) + lexical (BM25) search engine over text rows."""

    def __init__(
        self,
        base_folder: str = "vector_store",
        model_name: str = "sentence-transformers/LaBSE",
        index_type: str = "flat",
    ):
        """Prepare persistence folders and load the embedding model.

        Args:
            base_folder: Root directory for persisted index and rows.
            model_name: SentenceTransformer model id used for embeddings.
            index_type: FAISS index flavour: "flat", "ivf", or "hnsw".
        """
        self.base_folder = base_folder
        self.embeddings_folder = os.path.join(base_folder, "embeddings")
        self.docs_folder = os.path.join(base_folder, "documents")
        # Make sure both persistence locations exist up front.
        for folder in (self.embeddings_folder, self.docs_folder):
            os.makedirs(folder, exist_ok=True)

        self.model = SentenceTransformer(model_name)
        self.index: Optional[faiss.Index] = None   # built lazily
        self.rows: List[dict] = []                 # per-document metadata
        self.texts: List[str] = []                 # parallel search texts
        self.bm25: Optional[BM25Okapi] = None      # lexical scorer
        self.index_type = index_type
34
+
35
+ # -------------------------
36
+ # Index creation utilities
37
+ # -------------------------
38
+ def _create_index(self, dimension: int, embeddings: np.ndarray):
39
+ if self.index_type == "flat":
40
+ self.index = faiss.IndexFlatL2(dimension)
41
+ elif self.index_type == "ivf":
42
+ nlist = max(1, min(256, len(embeddings) // 10))
43
+ quantizer = faiss.IndexFlatL2(dimension)
44
+ self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)
45
+ self.index.train(np.array(embeddings).astype("float32"))
46
+ elif self.index_type == "hnsw":
47
+ self.index = faiss.IndexHNSWFlat(dimension, 32)
48
+ else:
49
+ raise ValueError(f"Unsupported index type: {self.index_type}")
50
+
51
+ def _persist(self):
52
+ try:
53
+ if self.index is not None:
54
+ faiss.write_index(self.index, os.path.join(self.embeddings_folder, "multilingual.index"))
55
+ with open(os.path.join(self.docs_folder, "rows.pkl"), "wb") as f:
56
+ pickle.dump(self.rows, f)
57
+ logger.info("Persisted index and rows to disk.")
58
+ except Exception as e:
59
+ logger.exception("Failed to persist index/rows: %s", e)
60
+
61
+ # -------------------------
62
+ # Core operations
63
+ # -------------------------
64
+ def encode_store(self, rows: List[dict], texts: List[str]):
65
+ try:
66
+ embeddings = self.model.encode(texts, convert_to_numpy=True)
67
+ dimension = embeddings.shape[1]
68
+ self._create_index(dimension, embeddings)
69
+ self.index.add(np.array(embeddings).astype("float32"))
70
+
71
+ self.rows = rows
72
+ self.texts = texts
73
+ tokenized_corpus = [t.lower().split() for t in texts]
74
+ self.bm25 = BM25Okapi(tokenized_corpus)
75
+
76
+ self._persist()
77
+ logger.info("Index built with %d rows (index_type=%s).", len(rows), self.index_type)
78
+ except Exception as e:
79
+ logger.exception("Error in encode_store: %s", e)
80
+ raise
81
+
82
+ def load(self):
83
+ try:
84
+ index_path = os.path.join(self.embeddings_folder, "multilingual.index")
85
+ rows_path = os.path.join(self.docs_folder, "rows.pkl")
86
+ if os.path.exists(index_path) and os.path.exists(rows_path):
87
+ self.index = faiss.read_index(index_path)
88
+ with open(rows_path, "rb") as f:
89
+ self.rows = pickle.load(f)
90
+ self.texts = [r["_search_text"] for r in self.rows]
91
+ tokenized_corpus = [t.lower().split() for t in self.texts]
92
+ self.bm25 = BM25Okapi(tokenized_corpus)
93
+ logger.info("Loaded index and %d rows from disk.", len(self.rows))
94
+ else:
95
+ logger.info("No persisted index/rows found.")
96
+ except Exception as e:
97
+ logger.exception("Error in load: %s", e)
98
+ raise
99
+
100
+ def add_rows(self, new_rows: List[dict], new_texts: List[str]):
101
+ try:
102
+ if not new_rows:
103
+ return
104
+
105
+ new_embeddings = self.model.encode(new_texts, convert_to_numpy=True).astype("float32")
106
+ if self.index is None:
107
+ self._create_index(new_embeddings.shape[1], new_embeddings)
108
+ self.index.add(new_embeddings)
109
+ else:
110
+ if isinstance(self.index, faiss.IndexIVFFlat) and not self.index.is_trained:
111
+ combined = np.vstack([self.model.encode(self.texts, convert_to_numpy=True).astype("float32"), new_embeddings]) if self.texts else new_embeddings
112
+ self.index.train(combined)
113
+ self.index.add(new_embeddings)
114
+
115
+ self.rows.extend(new_rows)
116
+ self.texts.extend(new_texts)
117
+ tokenized_corpus = [t.lower().split() for t in self.texts]
118
+ self.bm25 = BM25Okapi(tokenized_corpus)
119
+
120
+ self._persist()
121
+ logger.info("Added %d new rows. Total rows: %d", len(new_rows), len(self.rows))
122
+ except Exception as e:
123
+ logger.exception("Error in add_rows: %s", e)
124
+ raise
125
+
126
+ # -------------------------
127
+ # Search methods
128
+ # -------------------------
129
+ def search(self, query: str, top_k: int = 3):
130
+ try:
131
+ if self.index is None:
132
+ return []
133
+ query_emb = self.model.encode([query], convert_to_numpy=True).astype("float32")
134
+ k = min(top_k, len(self.rows))
135
+ distances, indices = self.index.search(query_emb, k=k)
136
+ results = [
137
+ {**self.rows[i], "distance": float(distances[0][j])}
138
+ for j, i in enumerate(indices[0])
139
+ ]
140
+ return sorted(results, key=lambda x: x["distance"])
141
+ except Exception as e:
142
+ logger.exception("Error in search: %s", e)
143
+ return []
144
+
145
+ def hybrid_search(self, query: str, top_k: int = 3, alpha: float = 0.5):
146
+ try:
147
+ if self.index is None or self.bm25 is None:
148
+ return []
149
+
150
+ # 🔹 Step 1: Encode query
151
+ query_emb = self.model.encode([query], convert_to_numpy=True).astype("float32")
152
+
153
+ # 🔹 Step 2: Retrieve top candidates (IMPORTANT)
154
+ retrieve_k = min(20, len(self.texts)) # candidate pool
155
+ distances, indices = self.index.search(query_emb, k=retrieve_k)
156
+
157
+ candidate_ids = indices[0]
158
+
159
+ # 🔹 Step 3: Semantic scores (convert distance → similarity)
160
+ sem_scores = {}
161
+ for j, i in enumerate(candidate_ids):
162
+ sim = 1 / (1 + distances[0][j])
163
+ sem_scores[i] = sim
164
+
165
+ # 🔹 Step 4: BM25 scores (only for candidates)
166
+ tokenized_query = query.lower().split()
167
+ bm25_scores = self.bm25.get_scores(tokenized_query)
168
+
169
+ lex_scores = {i: bm25_scores[i] for i in candidate_ids}
170
+
171
+ # 🔹 Step 5: NORMALIZATION (CRITICAL)
172
+ def normalize(scores_dict):
173
+ vals = list(scores_dict.values())
174
+ if not vals:
175
+ return scores_dict
176
+ min_v, max_v = min(vals), max(vals)
177
+ if max_v - min_v == 0:
178
+ return {k: 0.0 for k in scores_dict}
179
+ return {k: (v - min_v) / (max_v - min_v) for k, v in scores_dict.items()}
180
+
181
+ sem_scores = normalize(sem_scores)
182
+ lex_scores = normalize(lex_scores)
183
+
184
+ # 🔹 Step 6: Combine scores
185
+ combined = []
186
+ for i in candidate_ids:
187
+ sem = sem_scores.get(i, 0.0)
188
+ lex = lex_scores.get(i, 0.0)
189
+ score = alpha * sem + (1 - alpha) * lex
190
+
191
+ combined.append({**self.rows[i], "score": float(score)})
192
+
193
+ # 🔹 Step 7: Sort and return
194
+ combined = sorted(combined, key=lambda x: x["score"], reverse=True)
195
+
196
+ return combined[:top_k]
197
+
198
+ except Exception as e:
199
+ logger.exception("Error in hybrid_search: %s", e)
200
+ return []
201
+
202
+ # -------------------------
203
+ # Utilities
204
+ # -------------------------
205
+ def clear_vdb(self):
206
+ try:
207
+ if self.index is not None:
208
+ try:
209
+ self.index.reset()
210
+ except Exception:
211
+ self.index = None
212
+ self.rows = []
213
+ self.texts = []
214
+ self.bm25 = None
215
+
216
+ index_path = os.path.join(self.embeddings_folder, "multilingual.index")
217
+ docs_path = os.path.join(self.docs_folder, "rows.pkl")
218
+ if os.path.exists(index_path):
219
+ os.remove(index_path)
220
+ if os.path.exists(docs_path):
221
+ os.remove(docs_path)
222
+ logger.info("Cleared vector DB and persisted files.")
223
+ except Exception as e:
224
+ logger.exception("Error in clear_vdb: %s", e)
225
  raise