|
|
import logging
import os
import pickle

import faiss
import torch
from sentence_transformers import SentenceTransformer
|
|
|
# Configure root logging at INFO level on import. basicConfig is a no-op when
# the root logger already has handlers, so an embedding application's own
# logging setup takes precedence.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
# On-disk artefacts for the relationship search. These are relative paths,
# so they resolve against the current working directory.
RELATIONS = "relations"
REL_INDEX = RELATIONS + "/got_rels.faiss"    # FAISS index read at import time
REL_DATA = RELATIONS + "/got_rels_meta.pkl"  # pickled dict: name_map, sentences, metadata
|
|
|
# Load the embedding model, the FAISS index and its metadata once at import
# time; batch_relationships reuses these module-level objects on every call.
logger.info("Loading relationship FAISS index...")

# GPU use stays opt-in via CUDA_VISIBLE_DEVICES, but is only honoured when
# torch can actually see a CUDA device. Checking the variable alone is not
# enough: a truthy value such as "-1" (commonly used to hide all GPUs), or a
# CPU-only torch build, would otherwise select "cuda" and fail on model load.
_rel_device = (
    "cuda"
    if os.environ.get("CUDA_VISIBLE_DEVICES") and torch.cuda.is_available()
    else "cpu"
)
rel_model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=_rel_device,
)

rel_index = faiss.read_index(REL_INDEX)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load metadata produced by a trusted build step.
with open(REL_DATA, "rb") as f:
    rel_data = pickle.load(f)

# Name variant -> name used in search queries. Variants are matched against
# upper-cased questions in batch_relationships, so keys are presumably
# upper-case; verify against the build script.
name_map = rel_data["name_map"]
|
|
|
def _collect_sentences(indices, top_k, sentences, metadata):
    """Pick up to *top_k* sentences from FAISS hit *indices*, one per character.

    Hits are taken in the order FAISS returned them. A hit is skipped when its
    ``metadata[idx]["display_name"]`` has already contributed a sentence.
    """
    picked = []
    seen = set()
    for idx in indices:
        if idx == -1:  # FAISS fills unused result slots with -1
            continue
        char = metadata[idx]["display_name"]
        if char in seen:
            continue
        seen.add(char)
        picked.append(sentences[idx])
        if len(picked) >= top_k:
            break
    return picked


def batch_relationships(questions, top_k=3):
    """Return up to *top_k* relationship sentences for each question.

    For every question, character names are detected by case- and
    space-insensitive substring matching against the keys of ``name_map``.
    Variants shorter than 3 characters are ignored. The first two distinct
    mapped names (in ``name_map`` order) are turned into a query. All queries
    are embedded with ``rel_model`` in one call and searched in ``rel_index``
    in one call, retrieving ``2 * top_k`` neighbours each. Hits are then
    de-duplicated so each ``display_name`` contributes at most one sentence.

    Args:
        questions: Iterable of question strings; it is consumed exactly once.
        top_k: Maximum number of sentences returned per question.

    Returns:
        A list with one entry per question, in input order. Each entry is a
        list of strings:

        - ``["No known character relationships found"]`` when no character is
          recognised in the question.
        - ``["No confirmed relationships found"]`` when the search yields no
          usable hit.
    """
    # Prepare matchable variants once per call rather than once per question:
    # (space-stripped variant, mapped name), skipping variants short enough
    # to match almost anything.
    variants = [
        (variant.replace(" ", ""), mapped)
        for variant, mapped in name_map.items()
        if len(variant) >= 3
    ]

    batch_results = []
    pending = []  # (position in batch_results, query text) awaiting search
    for q in questions:
        # Comparing space-stripped forms also covers the plain substring case:
        # if a variant occurs in the question, it still occurs once both sides
        # have their spaces removed.
        q_compact = q.upper().replace(" ", "")
        matched = (mapped for compact, mapped in variants if compact in q_compact)
        candidates = list(dict.fromkeys(matched))[:2]
        if not candidates:
            batch_results.append(["No known character relationships found"])
            continue
        query = f"Relationships of {' and '.join(candidates)} in Game of Thrones books"
        pending.append((len(batch_results), query))
        batch_results.append(None)  # filled in after the batched search below

    if pending:
        # One encode call and one FAISS search for the whole batch instead of
        # one round-trip per question.
        q_vecs = rel_model.encode(
            [query for _, query in pending],
            normalize_embeddings=True,
            show_progress_bar=False,
        ).astype("float32")
        _, neighbours = rel_index.search(q_vecs, top_k * 2)

        sentences = rel_data["sentences"]
        metadata = rel_data["metadata"]
        for (pos, _), row in zip(pending, neighbours):
            results = _collect_sentences(row, top_k, sentences, metadata)
            batch_results[pos] = results if results else ["No confirmed relationships found"]

    return batch_results
|
|
|