|
|
| import os, re, json, pickle, logging, numpy as np, faiss
|
| from tqdm.notebook import tqdm
|
| from sentence_transformers import SentenceTransformer
|
| from langchain_community.retrievers import BM25Retriever
|
| from langchain.docstore.document import Document
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

logger = logging.getLogger(__name__)


# Artifact locations, relative to the current working directory.
WORK = "context"

JSONL = f"{WORK}/rag_documents.jsonl"  # one JSON record per line, each with "text" and "metadata"

FAISS_INDEX = f"{WORK}/faiss_ivf.index"

BM25_PICKLE = f"{WORK}/bm25_retriever.pkl"


def _load_jsonl(path):
    """Return the JSON records stored one per line in *path*.

    Blank / whitespace-only lines (e.g. a trailing empty line) are skipped
    instead of being fed to ``json.loads``, which would reject them. A
    malformed line is logged with its 1-based line number and the original
    ``json.JSONDecodeError`` is re-raised.
    """
    records = []
    with open(path, encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                logger.error("Malformed JSON on line %d of %s", line_no, path)
                raise
    return records


logger.info("Loading all RAG documents...")

ALL_DOCS = _load_jsonl(JSONL)


# Position of each record in ALL_DOCS -> its text / metadata. These 0-based
# positions equal physical JSONL line indices whenever the file has no blank
# lines. HybridRetriever looks FAISS ids and BM25 "line_no" values up in these
# dicts, so the FAISS index is assumed to have been built in this same record
# order — NOTE(review): confirm against the indexing script.
LINE_TO_TEXT = {i: doc["text"] for i, doc in enumerate(ALL_DOCS)}

LINE_TO_META = {i: doc["metadata"] for i, doc in enumerate(ALL_DOCS)}
|
|
|
| class HybridRetriever:
|
| def __init__(self):
|
|
|
| self.faiss_index = faiss.read_index(FAISS_INDEX)
|
| logger.info(f"FAISS loaded ({self.faiss_index.ntotal:,} vectors)")
|
|
|
|
|
| self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2",
|
| device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu")
|
|
|
|
|
| if os.path.exists(BM25_PICKLE):
|
| self.bm25 = pickle.load(open(BM25_PICKLE, "rb"))
|
| logger.info("BM25 loaded")
|
| else:
|
| logger.info("Building BM25...")
|
| docs = [Document(page_content=re.sub(r"^Filename:.*\nFullPath:.*\n\n", "",
|
| doc["text"], flags=re.M),
|
| metadata=doc["metadata"]) for doc in ALL_DOCS]
|
| self.bm25 = BM25Retriever.from_documents(docs)
|
| self.bm25.k = 30
|
| pickle.dump(self.bm25, open(BM25_PICKLE, "wb"))
|
| logger.info("BM25 built and saved")
|
|
|
| def batch_retrieve(self, queries, top_k=3, faiss_k=10, bm25_k=3):
|
| qvecs = self.model.encode(queries, show_progress_bar=False, normalize_embeddings=True).astype("float32")
|
| D, I = self.faiss_index.search(qvecs, faiss_k)
|
|
|
| batch_results = []
|
| for qi, (scores, indices) in enumerate(zip(D, I)):
|
| results = []
|
| seen = set()
|
| for score, idx in zip(scores, indices):
|
| if idx == -1 or idx in seen: continue
|
| results.append({"score": float(score), "text": LINE_TO_TEXT[idx],
|
| "metadata": LINE_TO_META[idx], "source": "FAISS"})
|
| seen.add(idx)
|
| if len(results) >= top_k: break
|
|
|
|
|
| bm25_docs = self.bm25.invoke(queries[qi])
|
| for doc in bm25_docs[:bm25_k]:
|
| ln = doc.metadata.get("line_no")
|
| if ln in seen: continue
|
| results.append({"score": 0.0, "text": LINE_TO_TEXT.get(ln, ""),
|
| "metadata": LINE_TO_META.get(ln, doc.metadata), "source": "BM25"})
|
| seen.add(ln)
|
| if len(results) >= top_k: break
|
| batch_results.append(results)
|
| return batch_results
|
|
|
|
|
# Module-level singleton. Constructing it at import time reads the FAISS index,
# loads the sentence-transformer model, and loads (or builds and pickles) the
# BM25 retriever, so importing this module performs all of that work.
retriever = HybridRetriever()
|
|
|