# relationship_retriever.py
"""Look up character-relationship sentences in a prebuilt FAISS index.

Importing this module configures logging, loads the sentence-embedding model,
reads the FAISS index from ``REL_INDEX`` and unpickles the metadata stored at
``REL_DATA``.  The metadata dict is read with the keys ``"name_map"``,
``"sentences"`` and ``"metadata"``.
"""
import logging
import os
import pickle

import faiss
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

RELATIONS = "relations"
REL_INDEX = f"{RELATIONS}/got_rels.faiss"
REL_DATA = f"{RELATIONS}/got_rels_meta.pkl"

logger.info("Loading relationship FAISS index...")
rel_model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    # GPU is selected only when CUDA_VISIBLE_DEVICES is set to a non-empty value.
    device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
)
rel_index = faiss.read_index(REL_INDEX)
# NOTE(security): pickle.load can execute arbitrary code while deserialising.
# Only point REL_DATA at a file produced by a trusted build step.
with open(REL_DATA, "rb") as f:
    rel_data = pickle.load(f)
name_map = rel_data["name_map"]


def batch_relationships(questions, top_k=3):
    """Return relationship sentences for each question.

    For every question, name variants from ``name_map`` are matched against
    the upper-cased question text (spaces ignored).  The first two distinct
    mapped names are turned into a query that is embedded and searched in
    ``rel_index``; hits are then de-duplicated by their ``display_name``.

    Args:
        questions: Iterable of question strings.  A single ``str`` is treated
            as one question instead of being iterated character by character.
        top_k: Maximum number of sentences returned per question.  The index
            is searched for ``top_k * 2`` neighbours to leave room for the
            per-character de-duplication.

    Returns:
        A list with one entry per question.  Each entry is a list of
        sentences, or a single-element list holding a placeholder message
        when no name variant matched or no hit survived filtering.
    """
    if isinstance(questions, str):
        questions = [questions]

    # Strip spaces from every usable variant once per call instead of once per
    # (question, variant) pair.  Variants shorter than three characters are
    # skipped, as before.
    # NOTE(review): matching against the upper-cased question presumes the
    # name_map keys are stored upper-case -- confirm against the index builder.
    variants = [
        (variant.replace(" ", ""), name)
        for variant, name in name_map.items()
        if len(variant) >= 3
    ]

    batch_results = []
    for question in questions:
        q_compact = question.upper().replace(" ", "")

        # If a variant occurs in the question verbatim, its space-stripped
        # form also occurs in the space-stripped question, so this single
        # test covers both of the original containment checks.
        matched = [name for compact, name in variants if compact in q_compact]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        candidates = list(dict.fromkeys(matched))[:2]

        if not candidates:
            batch_results.append(["No known character relationships found"])
            continue

        query = f"Relationships of {' and '.join(candidates)} in Game of Thrones books"
        q_vec = rel_model.encode(
            [query], normalize_embeddings=True, show_progress_bar=False
        ).astype("float32")
        _, neighbours = rel_index.search(q_vec, top_k * 2)

        results = []
        seen = set()
        for idx in neighbours[0]:
            # -1 is skipped: it is not a valid position in the metadata lists.
            if idx == -1:
                continue
            char = rel_data["metadata"][idx]["display_name"]
            if char in seen:
                continue
            # Keep only the first (best-ranked) sentence for each character.
            results.append(rel_data["sentences"][idx])
            seen.add(char)
            if len(results) >= top_k:
                break

        batch_results.append(results if results else ["No confirmed relationships found"])

    return batch_results