|
|
| import os, pickle, faiss
|
| from sentence_transformers import SentenceTransformer
|
| from typing import List, Dict, Any, Optional
|
|
|
# --- Retrieval configuration -------------------------------------------------
# Model identifier handed to SentenceTransformer when the embedder is created.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Presumably the cleaned Q&A source file; not referenced by the code in this
# section -- TODO confirm where it is consumed.
CLEAN_JSON = "qa_pairs/asoiaf_qa_clean.json"
# Serialized FAISS index, read with faiss.read_index().
INDEX_FILE = "qa_pairs/faiss_index.index"
# Pickle file whose contents are loaded into QA_PAIRS.
QA_DATA_FILE = "qa_pairs/qa_data.pkl"

# --- Lazily populated, module-wide caches ------------------------------------
# Embedding model instance; stays None until the first load.
EMBED_MODEL: Optional[SentenceTransformer] = None
# FAISS index object; stays None until the first load.
INDEX = None
# Q&A records; the search code reads the "question" and "answer" keys.
QA_PAIRS: List[Dict[str, Any]] = []
|
|
|
def _load_embed_model():
    """Return the shared SentenceTransformer, creating it on first use.

    The model is built once from ``MODEL_NAME`` and cached in the module-level
    ``EMBED_MODEL``; later calls return the cached instance.

    Device selection: a non-empty ``CUDA_VISIBLE_DEVICES`` environment
    variable opts in to GPU use, but "cuda" is only requested when torch
    reports a usable CUDA device. In every other case the model runs on CPU.

    Returns:
        The cached ``SentenceTransformer`` instance.
    """
    global EMBED_MODEL
    if EMBED_MODEL is None:
        device = "cpu"
        if os.environ.get("CUDA_VISIBLE_DEVICES"):
            try:
                # Imported lazily so the CPU-only path never needs torch here.
                import torch
            except ImportError:
                # Cannot verify availability; keep the env-var-only decision.
                device = "cuda"
            else:
                # The variable being set does not guarantee a working GPU
                # (e.g. value "-1", or a CPU-only torch build), so check
                # before asking for "cuda".
                if torch.cuda.is_available():
                    device = "cuda"
        EMBED_MODEL = SentenceTransformer(MODEL_NAME, device=device)
    return EMBED_MODEL
|
|
|
def build_or_load_index():
    """Load the FAISS index, the Q&A records and the embedding model (cached).

    Despite the name, nothing is built here: the index is read from
    ``INDEX_FILE`` and the records are unpickled from ``QA_DATA_FILE``.
    Results are cached in the module-level ``INDEX`` and ``QA_PAIRS``, so
    later calls skip the disk reads.

    Returns:
        A ``(index, qa_pairs, model)`` tuple. ``model`` always comes from
        ``_load_embed_model()``, so it is never ``None``.

    Raises:
        Any exception raised by ``faiss.read_index``, ``open`` or
        ``pickle.load`` propagates unchanged (e.g. when a file is missing).
    """
    global INDEX, QA_PAIRS

    # Identity check instead of truthiness: only "not loaded yet" should
    # trigger a reload, regardless of how the index object evaluates as a bool.
    if INDEX is not None and QA_PAIRS:
        # Go through the loader rather than reading EMBED_MODEL directly, so a
        # cache hit cannot hand back a model that was never initialised.
        return INDEX, QA_PAIRS, _load_embed_model()

    loaded_index = faiss.read_index(INDEX_FILE)
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point QA_DATA_FILE at data produced by a trusted process.
    with open(QA_DATA_FILE, "rb") as fh:
        loaded_pairs = pickle.load(fh)

    # Publish both caches together, and only after both loads succeeded, so a
    # failure above never leaves an index without its matching records.
    INDEX, QA_PAIRS = loaded_index, loaded_pairs
    return INDEX, QA_PAIRS, _load_embed_model()
|
def search_topk(query: str, index=None, qa_pairs=None, model=None, k: int = 5):
    """Return up to ``k`` similar Q&A entries for ``query``.

    Args:
        query: A query string, or a list of query strings.
        index: Object with a FAISS-style ``search(vectors, n)`` method that
            returns ``(scores, indices)``. Must be given together with
            ``qa_pairs``; if either is missing, both are taken from
            ``build_or_load_index()``.
        qa_pairs: Sequence of dicts addressed by the ids the index returns;
            the ``"question"`` and ``"answer"`` keys are read.
        model: Object with a SentenceTransformer-style ``encode`` method.
            A caller-supplied model is always used as given; when omitted,
            the shared model from ``_load_embed_model()`` is used.
        k: Maximum number of hits per query. ``k <= 0`` yields no hits.

    Returns:
        For a single query: a list of dicts with keys ``"similarity"``
        (float), ``"question"`` and ``"answer"`` (text before any
        ``"\\n\\nReference:"`` marker, stripped). For several queries: a list
        of such lists, one per query. A one-element list of queries also
        returns the flat single-query form; this long-standing behaviour is
        kept so existing callers are not broken.
    """
    queries = query if isinstance(query, list) else [query]

    # Nothing can be returned for a non-positive k; answer before loading any
    # resources and without forwarding an invalid size to the index.
    if k <= 0:
        return [] if len(queries) == 1 else [[] for _ in queries]

    # index and qa_pairs only make sense as a matching pair, so fall back to
    # the cached pair when either one is missing.
    if index is None or qa_pairs is None:
        index, qa_pairs, _ = build_or_load_index()
    # Resolved separately so an explicitly passed model is never replaced.
    if model is None:
        model = _load_embed_model()

    q_vecs = model.encode(
        queries,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    ).astype("float32")

    # Over-fetch so that k unique questions usually remain after duplicates
    # and unusable ids have been dropped.
    fetch_k = k * 3
    results = []
    for q_vec in q_vecs:
        scores, indices = index.search(q_vec[None, :], fetch_k)
        results.append(_unique_hits(scores[0], indices[0], qa_pairs, k))

    return results[0] if len(results) == 1 else results


def _unique_hits(scores, indices, qa_pairs, k):
    """Convert one row of search output into at most ``k`` unique-question hits."""
    hits = []
    seen_questions = set()
    for score, idx in zip(scores, indices):
        if len(hits) >= k:
            break
        # Skip ids that do not address a record (negative placeholders or ids
        # beyond the end of qa_pairs).
        if idx < 0 or idx >= len(qa_pairs):
            continue
        entry = qa_pairs[idx]
        question = entry.get("question", "")
        if question in seen_questions:
            continue
        seen_questions.add(question)
        # A stored answer of None is treated like a missing answer.
        raw_answer = entry.get("answer") or ""
        hits.append({
            "similarity": float(score),
            "question": question,
            "answer": raw_answer.split("\n\nReference:")[0].strip(),
        })
    return hits
|
|
|