got_retreivers / qa_retreiver.py
hash-map's picture
Upload 5 files
dff5c6e verified
# qa_retriever.py
import os, pickle, faiss
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional
# Model identifier passed to SentenceTransformer() when the embedder is first loaded.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Presumably the cleaned Q&A source data; not referenced by the code visible
# in this module -- TODO confirm whether it is still needed.
CLEAN_JSON = "qa_pairs/asoiaf_qa_clean.json"
# Path of the FAISS index that build_or_load_index() reads with faiss.read_index().
INDEX_FILE = "qa_pairs/faiss_index.index"
# Path of the pickled Q&A records that build_or_load_index() unpickles.
QA_DATA_FILE = "qa_pairs/qa_data.pkl"
# Module-level caches: they start empty and are filled on first use so the
# model, index and records are loaded at most once per process.
EMBED_MODEL: Optional[SentenceTransformer] = None
INDEX = None
QA_PAIRS: List[Dict[str, Any]] = []
def _load_embed_model():
    """Return the shared sentence-embedding model, creating it on first call.

    The model is cached in the module-level ``EMBED_MODEL`` so repeated calls
    are cheap.

    Device selection: the GPU is used only when ``CUDA_VISIBLE_DEVICES`` is set
    to a non-empty value *and* PyTorch reports a usable CUDA device. Checking
    the environment variable alone is not enough: a value such as ``"-1"`` or a
    CPU-only PyTorch build would otherwise select ``"cuda"`` and fail while the
    model is being constructed.

    Returns:
        The cached ``SentenceTransformer`` instance for ``MODEL_NAME``.
    """
    global EMBED_MODEL
    if EMBED_MODEL is None:
        # Local import keeps this block self-contained; torch is already a
        # dependency of sentence-transformers.
        import torch

        gpu_requested = bool(os.environ.get("CUDA_VISIBLE_DEVICES"))
        device = "cuda" if gpu_requested and torch.cuda.is_available() else "cpu"
        EMBED_MODEL = SentenceTransformer(MODEL_NAME, device=device)
    return EMBED_MODEL
def build_or_load_index():
    """Load the FAISS index, the Q&A records and the embedding model.

    The index and records are cached in the module-level ``INDEX`` and
    ``QA_PAIRS``; later calls reuse them instead of touching the disk again.
    Despite the name, nothing is built here: ``INDEX_FILE`` and
    ``QA_DATA_FILE`` must already exist, otherwise the underlying read calls
    raise.

    Returns:
        A ``(index, qa_pairs, model)`` tuple. The model is always obtained via
        ``_load_embed_model()``, so it is never ``None`` -- even if an earlier
        call cached the index and then failed while loading the model.
    """
    global INDEX, QA_PAIRS
    # `is None` avoids relying on the truthiness of a FAISS object. An empty
    # QA_PAIRS still triggers a reload, as before.
    if INDEX is None or not QA_PAIRS:
        loaded_index = faiss.read_index(INDEX_FILE)
        # NOTE(security): pickle.load can execute arbitrary code. This is only
        # safe while QA_DATA_FILE is a trusted, locally produced artifact.
        with open(QA_DATA_FILE, "rb") as f:
            loaded_pairs = pickle.load(f)
        # Publish both caches together so a failed read never leaves the
        # module with an index but no matching records.
        INDEX, QA_PAIRS = loaded_index, loaded_pairs
    return INDEX, QA_PAIRS, _load_embed_model()
def _collect_unique_hits(scores, indices, qa_pairs, k) -> List[Dict[str, Any]]:
    """Convert one row of index search output into at most `k` unique hits."""
    hits: List[Dict[str, Any]] = []
    seen_questions = set()
    for score, idx in zip(scores, indices):
        if len(hits) >= k:
            break
        # Skip ids that do not map to a record; FAISS uses -1 to pad a row
        # when it finds fewer neighbours than requested.
        if idx < 0 or idx >= len(qa_pairs):
            continue
        entry = qa_pairs[int(idx)]
        question = entry.get("question", "")
        # Keep only the first (best-ranked) hit for each question text.
        if question in seen_questions:
            continue
        seen_questions.add(question)
        # Treat a missing or None answer as empty instead of crashing on .split().
        answer = entry.get("answer") or ""
        hits.append({
            "similarity": float(score),
            "question": question,
            # Drop the trailing "Reference:" section, if any.
            "answer": answer.split("\n\nReference:")[0].strip(),
        })
    return hits


def search_topk(query: str, index=None, qa_pairs=None, model=None, k: int = 5):
    """Return up to `k` stored Q&A entries most similar to `query`.

    Args:
        query: A query string, or a list of query strings.
        index: Object exposing ``search(vectors, n)`` (e.g. a FAISS index).
            When this or `qa_pairs` is omitted, both are taken from
            ``build_or_load_index()`` so that they stay consistent.
        qa_pairs: Sequence of dicts with ``"question"`` and ``"answer"`` keys,
            positionally aligned with the ids returned by `index`.
        model: Object exposing a SentenceTransformer-style ``encode``. A model
            passed by the caller is always used; the shared model is loaded
            only when none is given.
        k: Maximum number of hits per query.

    Returns:
        A list of ``{"similarity", "question", "answer"}`` dicts when exactly
        one query was encoded (a plain string or a one-element list);
        otherwise a list containing one such list per query. ``similarity``
        is the raw score reported by `index`.
    """
    query_list = query if isinstance(query, list) else [query]

    if index is None or qa_pairs is None:
        # Only the index and records are taken from the cache; the model
        # returned alongside them must not override a caller-supplied one.
        index, qa_pairs, _ = build_or_load_index()
    if model is None:
        model = _load_embed_model()

    q_vecs = model.encode(
        query_list,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    ).astype("float32")

    results = []
    for q_vec in q_vecs:
        # Over-fetch (3x) so that k hits remain after duplicate questions and
        # invalid ids are dropped.
        scores, indices = index.search(q_vec[None, :], k * 3)
        results.append(_collect_unique_hits(scores[0], indices[0], qa_pairs, k))

    # Historical contract: a single encoded query yields a flat list.
    return results[0] if len(results) == 1 else results