junaid17 commited on
Commit
c46b826
·
verified ·
1 Parent(s): fa15a30

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +163 -0
  2. requirements.txt +14 -0
  3. text_engine.py +194 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# main.py
from fastapi import FastAPI, HTTPException, status, File, UploadFile, Form, Query
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional
import logging
import pandas as pd
import io
import os
from text_engine import Text_Search_Engine

logger = logging.getLogger(__name__)

app = FastAPI(title="CortexSearch", version="1.0", description="A flexible text search API with multiple FAISS index types and BM25 support.")

# Choose default index_type here: "flat", "ivf", or "hnsw"
store = Text_Search_Engine(index_type=os.getenv("INDEX_TYPE", "flat"))
try:
    store.load()
except Exception:
    # Best-effort startup: a missing or corrupt persisted index must not stop
    # the service, but silently swallowing the error (the old `pass`) hid real
    # problems such as a truncated pickle. Log it and continue empty.
    logger.exception("Could not load persisted index; starting with an empty store.")

app.add_middleware(
    CORSMiddleware,
    # NOTE(review): browsers reject credentialed requests when the
    # Access-Control-Allow-Origin is "*" — confirm whether credentials are
    # actually needed, or pin explicit origins.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
26
+
27
+
28
@app.get("/")
async def root():
    """Landing endpoint confirming that the API is up."""
    return {"message": "Welcome to the Flexible Text Intelligence API"}
31
+
32
+
33
# -------------------------
# Column preview endpoint
# -------------------------
@app.post("/list_columns")
async def list_columns(file: UploadFile = File(...)):
    """Return the column names of an uploaded CSV.

    Handy for previewing a file before deciding which columns to index.
    """
    try:
        raw = await file.read()
        frame = pd.read_csv(io.BytesIO(raw))
        return {"available_columns": list(frame.columns)}
    except Exception as exc:
        # Unparseable upload (not CSV, bad encoding, empty file) -> 400.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc))
48
+
49
+
50
# -------------------------
# Health check endpoint
# -------------------------
@app.get("/health")
async def health():
    """Report service liveness plus basic index statistics."""
    return {
        "status": "ok",
        "rows_indexed": len(store.rows),
        "index_type": store.index_type,
    }
56
+
57
+
58
# -------------------------
# Upload CSV (build fresh index)
# -------------------------
_VALID_INDEX_TYPES = {"flat", "ivf", "hnsw"}


@app.post("/upload_csv")
async def upload_csv(file: UploadFile = File(...), columns: str = Form(...), index_type: Optional[str] = Form(None)):
    """Upload a CSV and (re)build the search index from the chosen columns.

    Form fields:
        columns: comma-separated column names combined into the searchable text.
        index_type: optional override — 'flat', 'ivf', or 'hnsw'.

    Returns a status payload; unknown columns are reported with the list of
    available columns so the client can retry.
    """
    try:
        # BUG FIX: validate index_type *before* mutating the shared store.
        # Previously an invalid value was written to store.index_type first,
        # so the ValueError raised later by encode_store surfaced as a 500
        # AND left the engine broken for every subsequent request.
        if index_type and index_type not in _VALID_INDEX_TYPES:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid index_type '{index_type}'. Choose one of {sorted(_VALID_INDEX_TYPES)}.",
            )

        contents = await file.read()
        df = pd.read_csv(io.BytesIO(contents))

        column_list = [c.strip() for c in columns.split(",") if c.strip()]
        # Validate the requested columns against the uploaded file.
        for col in column_list:
            if col not in df.columns:
                return {
                    "status": "error",
                    "detail": f"Column '{col}' not found.",
                    "available_columns": list(df.columns),
                }

        # Drop rows missing any selected column, then build the combined
        # searchable text for each remaining record.
        rows = df.dropna(subset=column_list).to_dict(orient="records")
        for r in rows:
            r["_search_text"] = " ".join(str(r[col]) for col in column_list if r.get(col) is not None)

        texts = [r["_search_text"] for r in rows]

        if index_type:
            store.index_type = index_type

        store.encode_store(rows, texts)
        return {"status": "success", "count": len(rows), "used_columns": column_list, "index_type": store.index_type}
    except HTTPException:
        # Don't let the generic handler below downgrade a deliberate 4xx to a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
92
+
93
+
94
# -------------------------
# Add CSV (append new rows)
# -------------------------
@app.post("/add_csv")
async def add_csv(file: UploadFile = File(...), columns: str = Form(...)):
    """Append rows from a CSV to the existing index without rebuilding it."""
    try:
        payload = await file.read()
        frame = pd.read_csv(io.BytesIO(payload))

        wanted = [name.strip() for name in columns.split(",") if name.strip()]
        missing = [name for name in wanted if name not in frame.columns]
        if missing:
            # Report the first unknown column, same as the upload endpoint.
            return {
                "status": "error",
                "detail": f"Column '{missing[0]}' not found.",
                "available_columns": list(frame.columns),
            }

        fresh_rows = frame.dropna(subset=wanted).to_dict(orient="records")
        for record in fresh_rows:
            pieces = [str(record[name]) for name in wanted if record.get(name) is not None]
            record["_search_text"] = " ".join(pieces)

        fresh_texts = [record["_search_text"] for record in fresh_rows]
        store.add_rows(fresh_rows, fresh_texts)

        return {
            "status": "success",
            "added_count": len(fresh_rows),
            "used_columns": wanted,
            "total_rows": len(store.rows),
        }
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
123
+
124
+
125
# -------------------------
# Search endpoint
# -------------------------
@app.get("/search")
async def search(
    query: str,
    top_k: int = 3,
    mode: str = Query("semantic", enum=["semantic", "lexical", "hybrid"]),
    alpha: float = 0.5,
):
    """Search the index.

    mode: semantic | lexical | hybrid
    alpha: weight of the semantic score in hybrid mode (0..1).
    """
    try:
        if mode == "lexical":
            # Pure BM25 ranking over the stored texts.
            if store.bm25 is None:
                return {"results": []}
            scores = store.bm25.get_scores(query.lower().split())
            top = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)[:top_k]
            hits = [{**store.rows[idx], "score": float(val)} for idx, val in top]
        elif mode == "semantic":
            hits = store.search(query, top_k=top_k)
        else:
            hits = store.hybrid_search(query, top_k=top_k, alpha=alpha)

        return {"results": hits}
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
152
+
153
+
154
# -------------------------
# Delete all data
# -------------------------
@app.delete("/delete_data")
async def delete_data():
    """Drop every indexed row and remove the persisted files."""
    try:
        store.clear_vdb()
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc))
    return {"status": "success", "message": "Vector DB cleared"}
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ faiss-cpu
2
+ sentence_transformers
3
+ numpy
4
+ pandas
5
+ scikit-learn
6
+ torch
7
+ transformers
8
+ uvicorn
9
+ fastapi
10
+ python-multipart
11
+ rank_bm25
12
+ torchvision
13
+ pillow
14
+ git+https://github.com/openai/CLIP.git
text_engine.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # text_engine.py
2
+ import os
3
+ import pickle
4
+ import logging
5
+ from typing import List, Optional
6
+ import numpy as np
7
+ from sentence_transformers import SentenceTransformer
8
+ import faiss
9
+ from rank_bm25 import BM25Okapi
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class Text_Search_Engine:
    """Semantic + lexical text search over in-memory row dicts.

    Embeddings come from a SentenceTransformer model and live in a FAISS
    index ('flat', 'ivf', or 'hnsw'); a parallel BM25 index supports lexical
    and hybrid scoring. The FAISS index and the rows are persisted under
    ``base_folder`` and can be restored with :meth:`load`.
    """

    def __init__(
        self,
        base_folder: str = "vector_store",
        model_name: str = "sentence-transformers/LaBSE",
        index_type: str = "flat",
    ):
        self.base_folder = base_folder
        self.embeddings_folder = os.path.join(base_folder, "embeddings")
        self.docs_folder = os.path.join(base_folder, "documents")
        os.makedirs(self.embeddings_folder, exist_ok=True)
        os.makedirs(self.docs_folder, exist_ok=True)

        self.model = SentenceTransformer(model_name)
        self.index: Optional[faiss.Index] = None
        self.rows: List[dict] = []    # original records, each carrying "_search_text"
        self.texts: List[str] = []    # searchable text, parallel to self.rows
        self.bm25: Optional[BM25Okapi] = None
        self.index_type = index_type

    # -------------------------
    # Index creation utilities
    # -------------------------
    def _create_index(self, dimension: int, embeddings: np.ndarray):
        """Instantiate ``self.index`` for the configured ``index_type``.

        Raises:
            ValueError: if ``self.index_type`` is not 'flat', 'ivf' or 'hnsw'.
        """
        if self.index_type == "flat":
            self.index = faiss.IndexFlatL2(dimension)
        elif self.index_type == "ivf":
            # Rule of thumb: ~1 centroid per 10 vectors, capped at 256.
            nlist = max(1, min(256, len(embeddings) // 10))
            quantizer = faiss.IndexFlatL2(dimension)
            self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)
            self.index.train(np.asarray(embeddings, dtype="float32"))
        elif self.index_type == "hnsw":
            self.index = faiss.IndexHNSWFlat(dimension, 32)  # 32 links per node
        else:
            raise ValueError(f"Unsupported index type: {self.index_type}")

    def _persist(self):
        """Best-effort save of the FAISS index and rows to disk (never raises)."""
        try:
            if self.index is not None:
                faiss.write_index(self.index, os.path.join(self.embeddings_folder, "multilingual.index"))
            with open(os.path.join(self.docs_folder, "rows.pkl"), "wb") as f:
                pickle.dump(self.rows, f)
            logger.info("Persisted index and rows to disk.")
        except Exception as e:
            logger.exception("Failed to persist index/rows: %s", e)

    # -------------------------
    # Core operations
    # -------------------------
    def encode_store(self, rows: List[dict], texts: List[str]):
        """Replace the current index with one built from ``rows``/``texts``.

        Raises:
            Exception: re-raised after logging when embedding or indexing fails.
        """
        try:
            if not rows:
                # BUG FIX: an empty upload used to crash on embeddings.shape[1]
                # and on BM25Okapi([]) (division by zero). Store an explicit
                # empty state instead.
                self.index = None
                self.rows = []
                self.texts = []
                self.bm25 = None
                self._persist()
                logger.info("encode_store called with no rows; store cleared.")
                return

            embeddings = self.model.encode(texts, convert_to_numpy=True)
            dimension = embeddings.shape[1]
            self._create_index(dimension, embeddings)
            self.index.add(np.asarray(embeddings, dtype="float32"))

            self.rows = rows
            self.texts = texts
            # BM25 works on whitespace-tokenized, lowercased text.
            tokenized_corpus = [t.lower().split() for t in texts]
            self.bm25 = BM25Okapi(tokenized_corpus)

            self._persist()
            logger.info("Index built with %d rows (index_type=%s).", len(rows), self.index_type)
        except Exception as e:
            logger.exception("Error in encode_store: %s", e)
            raise

    def load(self):
        """Restore index, rows, texts and BM25 from disk, if both files exist.

        Raises:
            Exception: re-raised after logging when the persisted files are unreadable.
        """
        try:
            index_path = os.path.join(self.embeddings_folder, "multilingual.index")
            rows_path = os.path.join(self.docs_folder, "rows.pkl")
            if os.path.exists(index_path) and os.path.exists(rows_path):
                self.index = faiss.read_index(index_path)
                # NOTE: pickle.load is only safe because we wrote this file
                # ourselves; never point it at untrusted data.
                with open(rows_path, "rb") as f:
                    self.rows = pickle.load(f)
                self.texts = [r["_search_text"] for r in self.rows]
                tokenized_corpus = [t.lower().split() for t in self.texts]
                self.bm25 = BM25Okapi(tokenized_corpus) if tokenized_corpus else None
                logger.info("Loaded index and %d rows from disk.", len(self.rows))
            else:
                logger.info("No persisted index/rows found.")
        except Exception as e:
            logger.exception("Error in load: %s", e)
            raise

    def add_rows(self, new_rows: List[dict], new_texts: List[str]):
        """Append ``new_rows`` to the index, creating it first if needed.

        Raises:
            Exception: re-raised after logging when embedding or indexing fails.
        """
        try:
            if not new_rows:
                return

            new_embeddings = self.model.encode(new_texts, convert_to_numpy=True).astype("float32")
            if self.index is None:
                self._create_index(new_embeddings.shape[1], new_embeddings)
                self.index.add(new_embeddings)
            else:
                # An IVF index loaded untrained must be trained before add();
                # train on everything we have so centroids cover old + new data.
                if isinstance(self.index, faiss.IndexIVFFlat) and not self.index.is_trained:
                    if self.texts:
                        old = self.model.encode(self.texts, convert_to_numpy=True).astype("float32")
                        combined = np.vstack([old, new_embeddings])
                    else:
                        combined = new_embeddings
                    self.index.train(combined)
                self.index.add(new_embeddings)

            self.rows.extend(new_rows)
            self.texts.extend(new_texts)
            # BM25 has no incremental update; rebuild over the full corpus.
            tokenized_corpus = [t.lower().split() for t in self.texts]
            self.bm25 = BM25Okapi(tokenized_corpus)

            self._persist()
            logger.info("Added %d new rows. Total rows: %d", len(new_rows), len(self.rows))
        except Exception as e:
            logger.exception("Error in add_rows: %s", e)
            raise

    # -------------------------
    # Search methods
    # -------------------------
    def search(self, query: str, top_k: int = 3):
        """Return up to ``top_k`` rows nearest to ``query`` (L2 distance, ascending).

        Returns [] on any failure or when the store is empty.
        """
        try:
            if self.index is None or not self.rows:
                return []
            query_emb = self.model.encode([query], convert_to_numpy=True).astype("float32")
            k = min(top_k, len(self.rows))
            if k <= 0:
                # BUG FIX: faiss rejects k == 0 (e.g. top_k <= 0 requests).
                return []
            distances, indices = self.index.search(query_emb, k=k)
            results = [
                {**self.rows[i], "distance": float(distances[0][j])}
                for j, i in enumerate(indices[0])
                # BUG FIX: IVF indexes pad missing hits with -1, which would
                # silently alias self.rows[-1]; skip those slots.
                if 0 <= i < len(self.rows)
            ]
            return sorted(results, key=lambda x: x["distance"])
        except Exception as e:
            logger.exception("Error in search: %s", e)
            return []

    def hybrid_search(self, query: str, top_k: int = 3, alpha: float = 0.5):
        """Blend semantic and BM25 scores: alpha*semantic + (1-alpha)*lexical.

        NOTE(review): the semantic score 1/(1+d) and raw BM25 scores live on
        different scales, so alpha is not a true percentage weight — confirm
        whether min-max normalization is wanted.
        Returns [] on any failure or when either index is missing.
        """
        try:
            if self.index is None or self.bm25 is None or not self.texts:
                return []

            query_emb = self.model.encode([query], convert_to_numpy=True).astype("float32")
            distances, indices = self.index.search(query_emb, k=len(self.texts))
            # Convert L2 distance to a similarity in (0, 1]; skip -1 padding
            # (BUG FIX: -1 used to land in the dict as a bogus key).
            semantic_scores = {
                i: 1 / (1 + distances[0][j])
                for j, i in enumerate(indices[0])
                if 0 <= i < len(self.rows)
            }

            tokenized_query = query.lower().split()
            bm25_scores = self.bm25.get_scores(tokenized_query)
            lexical_scores = {i: bm25_scores[i] for i in range(len(self.texts))}

            combined = []
            for i, row in enumerate(self.rows):
                sem = semantic_scores.get(i, 0.0)
                lex = lexical_scores.get(i, 0.0)
                score = alpha * sem + (1 - alpha) * lex
                combined.append({**row, "score": float(score)})

            combined.sort(key=lambda x: x["score"], reverse=True)
            return combined[:top_k]
        except Exception as e:
            logger.exception("Error in hybrid_search: %s", e)
            return []

    # -------------------------
    # Utilities
    # -------------------------
    def clear_vdb(self):
        """Drop all in-memory state and delete the persisted files.

        Raises:
            Exception: re-raised after logging when file removal fails.
        """
        try:
            if self.index is not None:
                try:
                    self.index.reset()
                except Exception:
                    # Some index types may not support reset(); drop the object.
                    self.index = None
            self.rows = []
            self.texts = []
            self.bm25 = None

            index_path = os.path.join(self.embeddings_folder, "multilingual.index")
            docs_path = os.path.join(self.docs_folder, "rows.pkl")
            if os.path.exists(index_path):
                os.remove(index_path)
            if os.path.exists(docs_path):
                os.remove(docs_path)
            logger.info("Cleared vector DB and persisted files.")
        except Exception as e:
            logger.exception("Error in clear_vdb: %s", e)
            raise