got_retreivers / relationships_retreiver.py
hash-map's picture
Upload 5 files
dff5c6e verified
# relationship_retriever.py
import logging
import os
import pickle

import faiss
import torch
from sentence_transformers import SentenceTransformer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Artifact locations; these are relative paths, so they resolve against the
# current working directory of the process.
RELATIONS = "relations"
REL_INDEX = f"{RELATIONS}/got_rels.faiss"
REL_DATA = f"{RELATIONS}/got_rels_meta.pkl"

logger.info("Loading relationship FAISS index...")
# Ask torch directly whether a usable GPU exists. Inferring it from
# CUDA_VISIBLE_DEVICES is unreliable: the variable is often unset on GPU
# machines (silently falling back to CPU), and it can be set when no GPU is
# usable (e.g. "-1", or a CPU-only torch build), which made model loading
# fail with "cuda" selected.
rel_model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device="cuda" if torch.cuda.is_available() else "cpu",
)
rel_index = faiss.read_index(REL_INDEX)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load a trusted metadata artifact from this location.
with open(REL_DATA, "rb") as f:
    rel_data = pickle.load(f)

# Keys are name variants matched against the upper-cased question text
# (presumably stored upper-case — verify against the index builder); values
# are the names joined into the retrieval query.
name_map = rel_data["name_map"]
def _match_characters(question, limit=2):
    """Return up to `limit` distinct `name_map` values whose variant occurs in `question`.

    The question is upper-cased and spaces are removed from both the question
    and each variant before the substring test. Variants shorter than three
    characters are ignored. Matches keep `name_map` iteration order.
    """
    # Hoisted: the compacted question is the same for every variant.
    q_compact = question.upper().replace(" ", "")
    found = {}  # insertion-ordered set of matched names
    for variant, canonical in name_map.items():
        if len(variant) < 3:
            continue
        # A plain substring hit on the upper-cased question always implies a
        # hit once spaces are stripped from both sides, so this single
        # space-insensitive test is equivalent to checking both forms.
        if variant.replace(" ", "") in q_compact:
            found.setdefault(canonical, None)
            if len(found) >= limit:
                break
    return list(found)


def _collect_sentences(row, top_k):
    """Return up to `top_k` sentences for the index ids in `row`, one per character.

    `row` is one row of labels from `rel_index.search`; -1 entries (no hit)
    are skipped. A sentence is kept only if its metadata `display_name` has
    not been seen yet in this row.
    """
    sentences = []
    seen = set()
    for idx in row:
        if idx == -1:
            continue
        char = rel_data["metadata"][idx]["display_name"]
        if char in seen:
            continue
        seen.add(char)
        sentences.append(rel_data["sentences"][idx])
        if len(sentences) >= top_k:
            break
    return sentences


def batch_relationships(questions, top_k=3):
    """Retrieve relationship sentences for each question.

    For every question, up to two characters are detected by matching
    `name_map` variants against the text. Questions that mention at least one
    character are embedded together in a single model call and searched
    against the relationship index in a single query batch.

    Args:
        questions: Iterable of question strings.
        top_k: Maximum number of sentences returned per question. At most one
            sentence is kept per character (`display_name`).

    Returns:
        A list with one entry per question, in input order. Each entry is
        either the retrieved sentences or a one-element placeholder list:
        "No known character relationships found" when no character was
        recognised, or "No confirmed relationships found" when the search
        produced no usable hit.
    """
    matched = [_match_characters(q) for q in questions]
    # Build a fresh placeholder list per question so results never share state.
    batch_results = [["No known character relationships found"] for _ in matched]
    pending = [pos for pos, names in enumerate(matched) if names]
    if not pending:
        # Nothing to embed (including empty input); skip the model entirely.
        return batch_results

    queries = [
        f"Relationships of {' and '.join(matched[pos])} in Game of Thrones books"
        for pos in pending
    ]
    q_vecs = rel_model.encode(
        queries, normalize_embeddings=True, show_progress_bar=False
    ).astype("float32")
    # Over-fetch because the one-sentence-per-character filter may drop hits.
    _, neighbours = rel_index.search(q_vecs, top_k * 2)

    for pos, row in zip(pending, neighbours):
        batch_results[pos] = _collect_sentences(row, top_k) or [
            "No confirmed relationships found"
        ]
    return batch_results