import os import json import hashlib import random import pandas as pd from typing import Union, Tuple from threading import Lock import torch from pyserini.search.lucene import LuceneSearcher from pyserini.index.lucene import Document from pyserini.analysis import get_lucene_analyzer from pyserini.pyclass import autoclass from scipy.stats import hypergeom from subspace.tool import SubspaceBERTScore _CACHE_LOCK = Lock() _SEARCHER_CACHE = {} _HGF_ANALYZER = None _STANDARD_QUERY_PARSER_CLASS = None def _get_cached_searcher(index_path: str): with _CACHE_LOCK: searcher = _SEARCHER_CACHE.get(index_path) if searcher is None: searcher = LuceneSearcher(index_path) _SEARCHER_CACHE[index_path] = searcher return searcher def _get_cached_hgf_analyzer(): global _HGF_ANALYZER with _CACHE_LOCK: if _HGF_ANALYZER is None: _HGF_ANALYZER = get_lucene_analyzer( language='hgf_tokenizer', huggingFaceTokenizer='bert-base-uncased' ) return _HGF_ANALYZER # ============================================================================== # BM25 Search and Query Building # ============================================================================== def get_standard_query(query: str, field: str = "contents", analyzer=None): """ Runs Lucene's StandardQueryParser to get a parsed query object. """ if analyzer is None: analyzer = get_lucene_analyzer() global _STANDARD_QUERY_PARSER_CLASS if _STANDARD_QUERY_PARSER_CLASS is None: _STANDARD_QUERY_PARSER_CLASS = autoclass('org.apache.lucene.queryparser.flexible.standard.StandardQueryParser') query_parser = _STANDARD_QUERY_PARSER_CLASS() query_parser.setAnalyzer(analyzer) return query_parser.parse(query, field) def query_bm25_index(index_path: str, keywords: list, doc_count: int = None) -> pd.DataFrame: """Load index, run BM25 phrase search using custom HuggingFace analyzer, and return results.""" # 1. Load searcher searcher = _get_cached_searcher(index_path) # 2. Load custom analyzer matching your index strategy analyzer = _get_cached_hgf_analyzer() # 3. Create query string connecting keywords by OR (e.g., '"jesus" OR "christ"') query_string = " OR ".join([f'"{kw}"' for kw in keywords]) # 4. Build standard lucene query using your custom querybuilder phrase_q = get_standard_query(query_string, analyzer=analyzer) # 5. Search if doc_count is None: doc_count = searcher.num_docs hits = searcher.search(phrase_q, doc_count) # 6. Parse results results = [] returned_ids = set() for hit in hits: returned_ids.add(hit.docid) doc = Document(hit.lucene_document) raw = doc.raw() jd = json.loads(raw) row = { 'id': jd.get("id"), 'content': jd.get("contents", ""), 'score': hit.score } if "metadata" in jd and jd["metadata"]: metadata = json.loads(jd["metadata"]) row.update(metadata) results.append(row) returned_ext_ids = {r['id'] for r in results} # Pad with random unretrieved items exactly as required if len(results) < doc_count: needed = doc_count - len(results) total = searcher.num_docs # build a list of internal docnums whose external ID wasn't already returned pool = [] for docnum in range(total): lucene_doc = searcher.doc(docnum) doc = Document(lucene_doc) jd = json.loads(doc.raw()) ext_id = jd.get("id") if ext_id not in returned_ext_ids: pool.append(docnum) # deterministically shuffle by query md5 = hashlib.md5(query_string.encode("utf-8")).hexdigest() seed = int(md5, 16) % 2**32 rng = random.Random(seed) rng.shuffle(pool) # pull 'needed' more docs for docnum in pool[:needed]: lucene_doc = searcher.doc(docnum) doc = Document(lucene_doc) raw = doc.raw() jd = json.loads(raw) row = { "id": jd.get("id"), "content": jd.get("contents", ""), "score": None } if "metadata" in jd and jd["metadata"]: metadata = json.loads(jd["metadata"]) row.update(metadata) results.append(row) return pd.DataFrame(results) def warmup_bm25(index_path: str, warmup_keyword: str = "community", warmup_k: int = 1) -> None: """Pre-warm Lucene searcher/analyzer/parser so first user query is faster.""" if not index_path or not os.path.isdir(index_path): raise ValueError(f"Invalid index path for warmup: {index_path}") searcher = _get_cached_searcher(index_path) analyzer = _get_cached_hgf_analyzer() query = get_standard_query(f'"{warmup_keyword}"', analyzer=analyzer) searcher.search(query, max(1, int(warmup_k))) # ============================================================================== # Evaluation Metrics (Precision/Lift) # ============================================================================== def _resolve_k(df, k): """Convert float percentages to absolute k or return k as an int.""" if isinstance(k, float) and 0.0 < k <= 1.0: return int(len(df) * k) return int(k) def precision_at_k(df: pd.DataFrame, correct_demographic: str, k: Union[int, float]) -> float: """Calculate precision at k for a target demographic.""" rel = (df['demographic'] == correct_demographic).astype(int) k_abs = _resolve_k(df, k) if k_abs <= 0: return 0.0 return rel.iloc[:k_abs].sum() / float(k_abs) def lift_at_k(df: pd.DataFrame, correct_demographic: str, k: Union[int, float]) -> float: """Lift@k: ratio of precision@k to the overall proportion of relevant items.""" k_abs = _resolve_k(df, k) if k_abs <= 0 or len(df) == 0: return 0.0 precision_k = precision_at_k(df, correct_demographic, k) rel = (df['demographic'] == correct_demographic).astype(int) overall_proportion = rel.sum() / float(len(df)) if overall_proportion == 0: return 0.0 return precision_k / overall_proportion def hypergeometric_significance_test(df: pd.DataFrame, correct_demographic: str, k: Union[int, float], alpha: float = 0.05) -> Tuple[float, Tuple[int, int], Tuple[float, float]]: """Hypergeometric statistical significance test for the retrieval.""" n = _resolve_k(df, k) N = len(df) rel = (df['demographic'] == correct_demographic).astype(int) K = rel.sum() k_obs = rel.iloc[:n].sum() if K == 0 or n <= 0: return 0.0, (0, 0), (0.0, 0.0) p_value = hypergeom.sf(k_obs - 1, N, K, n) L = int(hypergeom.ppf(alpha/2, N, K, n)) U = int(hypergeom.isf(alpha/2, N, K, n)) return p_value, (L, U), (L / n, U / n) def lift_ci(df: pd.DataFrame, correct_demographic: str, k: Union[int, float], alpha: float = 0.05) -> Tuple[float, float, float]: """Calculate confidence interval for lift@k using hypergeometric distribution.""" n = _resolve_k(df, k) N = len(df) rel = (df['demographic'] == correct_demographic).astype(int) K = rel.sum() overall_proportion = K / float(N) if K == 0 or n <= 0 or overall_proportion == 0: return 0.0, 0.0, 0.0 pval, (L, U), _ = hypergeometric_significance_test(df, correct_demographic, k, alpha) lower_bound_lift = (L / n) / overall_proportion upper_bound_lift = (U / n) / overall_proportion return pval, lower_bound_lift, upper_bound_lift # ============================================================================== # Keyword Similarity (SubspaceBERTScore) # ============================================================================== _GLOBAL_SCORER = None def compute_keyword_similarity(set1: list, set2: list, device: str = None) -> dict: """ Computes precision, recall, and F-score similarity metrics between two keyword sets. Mirrors the subspace-based BERT scoring logic handling keyword lists. """ global _GLOBAL_SCORER if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' if _GLOBAL_SCORER is None: print(f"Initializing BERT model on {device}...") _GLOBAL_SCORER = SubspaceBERTScore(device=device, model_name_or_path='bert-base-uncased') sentence_1 = [", ".join(set1)] sentence_2 = [", ".join(set2)] scores = _GLOBAL_SCORER(sentence_1, sentence_2) return { 'Precision': scores[0].item(), 'Recall': scores[1].item(), 'F-Score': scores[2].item() }