# qa_retriever.py
"""FAISS-backed semantic retrieval over pre-embedded Q&A pairs.

Module-level singletons (embedding model, FAISS index, Q&A metadata) are
loaded lazily on first use and reused across calls.
"""
import os
import pickle
from typing import Any, Dict, List, Optional

import faiss
from sentence_transformers import SentenceTransformer

MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CLEAN_JSON = "qa_pairs/asoiaf_qa_clean.json"
INDEX_FILE = "qa_pairs/faiss_index.index"
QA_DATA_FILE = "qa_pairs/qa_data.pkl"

# Lazily-initialized module-level caches; populated on first use.
EMBED_MODEL: Optional[SentenceTransformer] = None
INDEX = None
QA_PAIRS: List[Dict[str, Any]] = []


def _load_embed_model() -> SentenceTransformer:
    """Load (once) and return the sentence-embedding model.

    CUDA is used only when CUDA_VISIBLE_DEVICES is set and non-empty;
    otherwise the model runs on CPU.
    """
    global EMBED_MODEL
    if EMBED_MODEL is None:
        device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
        EMBED_MODEL = SentenceTransformer(MODEL_NAME, device=device)
    return EMBED_MODEL


def build_or_load_index():
    """Return (index, qa_pairs, model), reading index/data from disk once.

    Subsequent calls hit the module-level cache.

    NOTE(review): pickle.load executes arbitrary code if QA_DATA_FILE is
    untrusted — acceptable only for locally generated artifacts.
    """
    global INDEX, QA_PAIRS
    # `is not None` instead of truthiness: a SWIG-wrapped faiss index has no
    # meaningful bool. Bug fix: the cached path previously returned
    # EMBED_MODEL even when it was still None; always resolve the model.
    if INDEX is not None and QA_PAIRS:
        return INDEX, QA_PAIRS, _load_embed_model()
    INDEX = faiss.read_index(INDEX_FILE)
    with open(QA_DATA_FILE, "rb") as f:
        QA_PAIRS = pickle.load(f)
    return INDEX, QA_PAIRS, _load_embed_model()


def search_topk(query, index=None, qa_pairs=None, model=None, k: int = 5):
    """Return up to `k` similar Q&A entries per query.

    Args:
        query: a single query string, or a list of query strings.
        index: optional FAISS index; loaded from disk when omitted.
        qa_pairs: optional list of {"question", "answer"} dicts.
        model: optional SentenceTransformer; resolved lazily when omitted.
        k: maximum number of unique results per query.

    Returns:
        A list of result dicts ({"similarity", "question", "answer"}) for a
        single query, or a list of such lists for a list of queries.
    """
    query_list = query if isinstance(query, list) else [query]

    # Resolve missing resources exactly once. (Previously the model could be
    # loaded and then immediately overwritten by build_or_load_index().)
    if index is None or qa_pairs is None:
        index, qa_pairs, model = build_or_load_index()
    elif model is None:
        model = _load_embed_model()

    q_vecs = model.encode(
        query_list,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    ).astype("float32")

    # One batched FAISS search for all queries instead of one call per query.
    # Over-fetch k*3 candidates so duplicate question texts can be skipped
    # while still (usually) filling k unique slots.
    all_scores, all_indices = index.search(q_vecs, k * 3)

    n_pairs = len(qa_pairs)
    results = []
    for scores, indices in zip(all_scores, all_indices):
        seen = set()
        q_results = []
        for score, idx in zip(scores, indices):
            if len(q_results) >= k:
                break
            # FAISS pads with -1 when fewer than k*3 neighbors exist.
            if idx < 0 or idx >= n_pairs:
                continue
            q_text = qa_pairs[idx].get("question", "")
            if q_text in seen:
                continue  # skip duplicate questions
            seen.add(q_text)
            raw_ans = qa_pairs[idx].get("answer", "")
            # Strip the trailing "Reference:" citation block from answers.
            clean_ans = raw_ans.split("\n\nReference:")[0].strip()
            q_results.append(
                {
                    "similarity": float(score),
                    "question": q_text,
                    "answer": clean_ans,
                }
            )
        results.append(q_results)

    # Preserve the original convenience contract: a bare string query gets a
    # flat list back, not a one-element list of lists.
    return results[0] if len(results) == 1 else results