Eylon Caplan committed on
Commit ·
ddb7b62
1
Parent(s): c82bc86
Deploy app code targeting HF Storage Bucket
Browse files- .gitignore +1 -0
- app.py +96 -0
- core_logic.py +211 -0
- packages.txt +1 -0
- requirements.txt +6 -0
- subspace/__init__.py +9 -0
- subspace/__pycache__/__init__.cpython-310.pyc +0 -0
- subspace/__pycache__/fuzzy.cpython-310.pyc +0 -0
- subspace/__pycache__/similarity.cpython-310.pyc +0 -0
- subspace/__pycache__/symbolic.cpython-310.pyc +0 -0
- subspace/__pycache__/tool.cpython-310.pyc +0 -0
- subspace/fuzzy.py +47 -0
- subspace/grassmannian.py +26 -0
- subspace/legacy_operations/__init__.py +1 -0
- subspace/legacy_operations/operations.py +103 -0
- subspace/operations.py +108 -0
- subspace/optimal_transport.py +47 -0
- subspace/similarity.py +162 -0
- subspace/symbolic.py +29 -0
- subspace/tool.py +72 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
bm25_indexes/
|
app.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import os
|
| 4 |
+
from core_logic import (
|
| 5 |
+
query_bm25_index,
|
| 6 |
+
lift_at_k,
|
| 7 |
+
lift_ci,
|
| 8 |
+
compute_keyword_similarity
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
# Path to the index directory where the Hugging Face Storage Bucket is mounted.
|
| 12 |
+
# Assuming the bucket is mounted at /data in the Space settings.
|
| 13 |
+
INDEX_DIR = '/data/bm25_indexes'
|
| 14 |
+
|
| 15 |
+
def get_available_indices():
    """Return the names of BM25 index subdirectories under INDEX_DIR.

    Each subdirectory of the mounted bucket path is assumed to be one
    Lucene index. Returns a one-element placeholder list when the mount
    point does not exist (e.g. the Space bucket is not attached).
    """
    if not os.path.exists(INDEX_DIR):
        return ["No indices found"]
    return [
        entry
        for entry in os.listdir(INDEX_DIR)
        if os.path.isdir(os.path.join(INDEX_DIR, entry))
    ]
|
| 19 |
+
|
| 20 |
+
def evaluate_keywords(index_name, target_demo, seed_words_str, generated_words_str):
    """Run the full evaluation pipeline for one button click.

    :param index_name: name of the BM25 index directory under INDEX_DIR
    :param target_demo: demographic label whose lift is measured
    :param seed_words_str: comma-separated reference keywords
    :param generated_words_str: comma-separated candidate keywords
    :return: (lift markdown, similarity markdown, top-10 DataFrame);
             on any failure, (error string, "", empty DataFrame)
    """
    try:
        # Parse inputs (split on commas, drop empty fragments)
        seed_words = [w.strip() for w in seed_words_str.split(",") if w.strip()]
        generated_words = [w.strip() for w in generated_words_str.split(",") if w.strip()]

        index_path = os.path.join(INDEX_DIR, index_name)

        # 1. Compute BM25 Lifts for the GENERATED words
        df_results = query_bm25_index(index_path, generated_words, doc_count=1000)

        lift_100 = lift_at_k(df_results, target_demo, k=100)
        pval_100, ci_lower_100, ci_upper_100 = lift_ci(df_results, target_demo, k=100)

        # Float k is interpreted as a fraction of the result set (5% here).
        lift_5_percent = lift_at_k(df_results, target_demo, k=0.05)
        pval_5, ci_lower_5, ci_upper_5 = lift_ci(df_results, target_demo, k=0.05)

        lift_text = (
            f"**Lift@100:** {lift_100:.3f} (p={pval_100:.4f}, 95% CI: [{ci_lower_100:.3f}, {ci_upper_100:.3f}])\n"
            f"**Lift@5%:** {lift_5_percent:.3f} (p={pval_5:.4f}, 95% CI: [{ci_lower_5:.3f}, {ci_upper_5:.3f}])"
        )

        # 2. Compute BERT Similarity (CPU-pinned for the Space hardware)
        sim_metrics = compute_keyword_similarity(seed_words, generated_words, device='cpu')
        sim_text = (
            f"**Precision:** {sim_metrics['Precision']:.4f}\n"
            f"**Recall:** {sim_metrics['Recall']:.4f}\n"
            f"**F-Score:** {sim_metrics['F-Score']:.4f}"
        )

        # 3. Preview Top 10 hits
        # NOTE(review): assumes query_bm25_index rows carry a 'demographic'
        # key from the document metadata — verify against the index schema.
        top_hits = df_results.head(10)[['id', 'score', 'demographic', 'content']]

        return lift_text, sim_text, top_hits

    except Exception as e:
        # Surface the error in the first output slot rather than crashing the UI.
        return f"Error: {str(e)}", "", pd.DataFrame()
|
| 57 |
+
|
| 58 |
+
# Gradio Interface
|
| 59 |
+
# Gradio Interface: left column holds the inputs, right column the outputs.
with gr.Blocks(title="BM25 Splits Demo") as demo:
    gr.Markdown("# 🚀 BM25 Target Demographic Evaluation Demo")
    gr.Markdown("Test retrieved demographic splits against predefined seed keywords and BERT Subspace metrics.")

    with gr.Row():
        with gr.Column():
            # Index choices are read from the mounted bucket at app start-up.
            index_dropdown = gr.Dropdown(choices=get_available_indices(), label="Select BM25 Index")
            target_demo_input = gr.Textbox(label="Target Demographic (e.g., 'jewish', 'black')", value="jewish")

            seed_words_input = gr.Textbox(
                label="Target Demographic Seed Words (Comma separated)",
                value="the, be, to, of, and, a, in, that, have, I, it, for, not, on, with, he, as, you, do, at"
            )
            generated_words_input = gr.Textbox(
                label="Your Subspace/Generated Keywords (Comma separated)",
                value="church, jesus, christ, prayer"
            )

            submit_btn = gr.Button("Run Compute", variant="primary")

        with gr.Column():
            # Output placeholders, overwritten by evaluate_keywords on click.
            gr.Markdown("### 📊 Similarity Metrics (BERT-Score)")
            sim_output = gr.Markdown("Waiting to run...")

            gr.Markdown("### 📈 Lift Metrics (BM25)")
            lift_output = gr.Markdown("Waiting to run...")

            gr.Markdown("### 🔍 Top 10 Retrieved Hits")
            table_output = gr.Dataframe()

    # Wire the button to the pipeline; output order must match
    # evaluate_keywords' return tuple (lift, similarity, table).
    submit_btn.click(
        fn=evaluate_keywords,
        inputs=[index_dropdown, target_demo_input, seed_words_input, generated_words_input],
        outputs=[lift_output, sim_output, table_output]
    )

if __name__ == "__main__":
    # Bind to all interfaces on the port Hugging Face Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
core_logic.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import hashlib
|
| 4 |
+
import random
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from typing import Union, Tuple
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from pyserini.search.lucene import LuceneSearcher
|
| 10 |
+
from pyserini.index.lucene import Document
|
| 11 |
+
from pyserini.analysis import get_lucene_analyzer
|
| 12 |
+
from pyserini.pyclass import autoclass
|
| 13 |
+
from scipy.stats import hypergeom
|
| 14 |
+
from subspace.tool import SubspaceBERTScore
|
| 15 |
+
|
| 16 |
+
# ==============================================================================
|
| 17 |
+
# BM25 Search and Query Building
|
| 18 |
+
# ==============================================================================
|
| 19 |
+
|
| 20 |
+
def get_standard_query(query: str, field: str = "contents", analyzer=None):
    """
    Runs Lucene's StandardQueryParser to get a parsed query object.

    :param query: raw query string (may contain quoted phrases and OR syntax)
    :param field: default index field the query is parsed against
    :param analyzer: optional Lucene analyzer; defaults to Pyserini's standard one
    :return: the parsed Lucene query object
    """
    if analyzer is None:
        analyzer = get_lucene_analyzer()

    # Reach into the JVM for the flexible StandardQueryParser, which Pyserini
    # does not expose through its own Python API.
    JStandardQueryParser = autoclass('org.apache.lucene.queryparser.flexible.standard.StandardQueryParser')
    query_parser = JStandardQueryParser()
    query_parser.setAnalyzer(analyzer)

    return query_parser.parse(query, field)
|
| 32 |
+
|
| 33 |
+
def query_bm25_index(index_path: str, keywords: list, doc_count: int = 1000) -> pd.DataFrame:
    """Load index, run BM25 phrase search using custom HuggingFace analyzer, and return results.

    Each keyword is quoted and OR-joined into one query. Hits are parsed
    from their raw JSON (id / contents / score plus flattened metadata).
    If fewer than ``doc_count`` documents match, the result is padded with
    deterministically-sampled unretrieved documents whose score is None,
    so downstream lift metrics always see exactly ``doc_count`` rows.

    :param index_path: path to the Lucene index directory
    :param keywords: list of query keywords
    :param doc_count: total number of rows to return (hits + random padding)
    :return: DataFrame with at least 'id', 'content', 'score' columns
    """
    # 1. Load searcher
    searcher = LuceneSearcher(index_path)

    # 2. Load custom analyzer matching your index strategy
    analyzer = get_lucene_analyzer(
        language='hgf_tokenizer',
        huggingFaceTokenizer='bert-base-uncased'
    )

    # 3. Create query string connecting keywords by OR (e.g., '"jesus" OR "christ"')
    query_string = " OR ".join([f'"{kw}"' for kw in keywords])

    # 4. Build standard lucene query using your custom querybuilder
    phrase_q = get_standard_query(query_string, analyzer=analyzer)

    # 5. Search
    hits = searcher.search(phrase_q, doc_count)

    # 6. Parse results
    results = []
    returned_ids = set()

    for hit in hits:
        returned_ids.add(hit.docid)
        doc = Document(hit.lucene_document)
        raw = doc.raw()
        jd = json.loads(raw)

        row = {
            'id': jd.get("id"),
            'content': jd.get("contents", ""),
            'score': hit.score
        }

        # Metadata is stored as a JSON string inside the raw doc; flatten
        # its keys (e.g. the 'demographic' label) into the row.
        if "metadata" in jd and jd["metadata"]:
            metadata = json.loads(jd["metadata"])
            row.update(metadata)

        results.append(row)

    returned_ext_ids = {r['id'] for r in results}

    # Pad with random unretrieved items exactly as required
    if len(results) < doc_count:
        needed = doc_count - len(results)
        total = searcher.num_docs

        # build a list of internal docnums whose external ID wasn't already returned
        # NOTE(review): this scans and JSON-parses every document in the
        # index on each under-full query — O(num_docs); fine for small
        # indexes, worth pre-computing for large ones.
        pool = []
        for docnum in range(total):
            lucene_doc = searcher.doc(docnum)
            doc = Document(lucene_doc)
            jd = json.loads(doc.raw())
            ext_id = jd.get("id")
            if ext_id not in returned_ext_ids:
                pool.append(docnum)

        # deterministically shuffle by query (MD5 of the query string seeds
        # the RNG, so identical queries always pad with the same documents)
        md5 = hashlib.md5(query_string.encode("utf-8")).hexdigest()
        seed = int(md5, 16) % 2**32
        rng = random.Random(seed)
        rng.shuffle(pool)

        # pull 'needed' more docs
        for docnum in pool[:needed]:
            lucene_doc = searcher.doc(docnum)
            doc = Document(lucene_doc)
            raw = doc.raw()
            jd = json.loads(raw)

            row = {
                "id": jd.get("id"),
                "content": jd.get("contents", ""),
                # Padded rows carry no BM25 score
                "score": None
            }
            if "metadata" in jd and jd["metadata"]:
                metadata = json.loads(jd["metadata"])
                row.update(metadata)

            results.append(row)

    return pd.DataFrame(results)
|
| 117 |
+
|
| 118 |
+
# ==============================================================================
|
| 119 |
+
# Evaluation Metrics (Precision/Lift)
|
| 120 |
+
# ==============================================================================
|
| 121 |
+
|
| 122 |
+
def _resolve_k(df, k):
|
| 123 |
+
"""Convert float percentages to absolute k or return k as an int."""
|
| 124 |
+
if isinstance(k, float) and 0.0 < k <= 1.0:
|
| 125 |
+
return int(len(df) * k)
|
| 126 |
+
return int(k)
|
| 127 |
+
|
| 128 |
+
def precision_at_k(df: pd.DataFrame, correct_demographic: str, k: Union[int, float]) -> float:
    """Precision@k: fraction of the top-k rows whose 'demographic' matches.

    k may be an absolute count or a float fraction (see _resolve_k).
    Returns 0.0 for a non-positive cutoff.
    """
    cutoff = _resolve_k(df, k)
    if cutoff <= 0:
        return 0.0
    hits = (df['demographic'].iloc[:cutoff] == correct_demographic).astype(int).sum()
    return hits / float(cutoff)
|
| 135 |
+
|
| 136 |
+
def lift_at_k(df: pd.DataFrame, correct_demographic: str, k: Union[int, float]) -> float:
    """Lift@k: precision@k divided by the target's base rate in df.

    Returns 0.0 when the cutoff is non-positive, df is empty, or the
    target demographic never appears.
    """
    cutoff = _resolve_k(df, k)
    total = len(df)
    if cutoff <= 0 or total == 0:
        return 0.0

    top_precision = precision_at_k(df, correct_demographic, k)
    base_rate = (df['demographic'] == correct_demographic).astype(int).sum() / float(total)

    return 0.0 if base_rate == 0 else top_precision / base_rate
|
| 150 |
+
|
| 151 |
+
def hypergeometric_significance_test(df: pd.DataFrame, correct_demographic: str, k: Union[int, float], alpha: float = 0.05) -> Tuple[float, Tuple[int, int], Tuple[float, float]]:
    """Hypergeometric significance test of the top-k retrieval.

    Models the top-k sample as draws without replacement from df.
    Returns (one-sided p-value for observing at least the seen number of
    relevant docs, (lower, upper) count interval at level alpha,
    the same interval as precision fractions). All-zero tuple when the
    target is absent or the cutoff is non-positive.
    """
    sample_size = _resolve_k(df, k)
    population = len(df)

    relevant = (df['demographic'] == correct_demographic).astype(int)
    total_relevant = relevant.sum()
    observed = relevant.iloc[:sample_size].sum()

    if total_relevant == 0 or sample_size <= 0:
        return 0.0, (0, 0), (0.0, 0.0)

    # sf(x - 1) == P(X >= x) for the discrete hypergeometric.
    p_value = hypergeom.sf(observed - 1, population, total_relevant, sample_size)
    lower = int(hypergeom.ppf(alpha / 2, population, total_relevant, sample_size))
    upper = int(hypergeom.isf(alpha / 2, population, total_relevant, sample_size))

    return p_value, (lower, upper), (lower / sample_size, upper / sample_size)
|
| 168 |
+
|
| 169 |
+
def lift_ci(df: pd.DataFrame, correct_demographic: str, k: Union[int, float], alpha: float = 0.05) -> Tuple[float, float, float]:
    """Confidence interval for lift@k via the hypergeometric distribution.

    Returns (p_value, lower_lift, upper_lift); the bounds are the
    hypergeometric count interval converted to precision and divided by
    the base rate. All zeros when the target is absent or cutoff <= 0.
    """
    sample_size = _resolve_k(df, k)
    population = len(df)

    relevant = (df['demographic'] == correct_demographic).astype(int)
    total_relevant = relevant.sum()
    base_rate = total_relevant / float(population)

    if total_relevant == 0 or sample_size <= 0 or base_rate == 0:
        return 0.0, 0.0, 0.0

    pval, (count_lo, count_hi), _ = hypergeometric_significance_test(df, correct_demographic, k, alpha)
    lower_lift = (count_lo / sample_size) / base_rate
    upper_lift = (count_hi / sample_size) / base_rate

    return pval, lower_lift, upper_lift
|
| 186 |
+
|
| 187 |
+
# ==============================================================================
|
| 188 |
+
# Keyword Similarity (SubspaceBERTScore)
|
| 189 |
+
# ==============================================================================
|
| 190 |
+
|
| 191 |
+
def compute_keyword_similarity(set1: list, set2: list, device: str = None) -> dict:
    """
    Computes precision, recall, and F-score similarity metrics between two keyword sets.

    Each keyword list is joined into a single comma-separated "sentence"
    and scored with SubspaceBERTScore. Mirrors the subspace-based BERT
    scoring logic handling keyword lists.

    :param set1: reference keywords
    :param set2: candidate keywords
    :param device: 'cpu' / 'cuda'; auto-detected when None
    :return: dict with 'Precision', 'Recall', 'F-Score' floats
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Fix: the original re-instantiated SubspaceBERTScore (and thus reloaded
    # the BERT weights) on every call. Cache one scorer per device on the
    # function object so the model loads once per process.
    cache = getattr(compute_keyword_similarity, '_scorer_cache', None)
    if cache is None:
        cache = {}
        compute_keyword_similarity._scorer_cache = cache
    scorer = cache.get(device)
    if scorer is None:
        print(f"Initializing BERT model on {device}...")
        scorer = SubspaceBERTScore(device=device, model_name_or_path='bert-base-uncased')
        cache[device] = scorer

    sentence_1 = [", ".join(set1)]
    sentence_2 = [", ".join(set2)]

    scores = scorer(sentence_1, sentence_2)

    return {
        'Precision': scores[0].item(),
        'Recall': scores[1].item(),
        'F-Score': scores[2].item()
    }
|
packages.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
default-jre
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pyserini==0.44.0
|
| 2 |
+
pandas==2.2.3
|
| 3 |
+
scipy==1.15.2
|
| 4 |
+
transformers==4.53.2
|
| 5 |
+
torch
|
| 6 |
+
gradio
|
subspace/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .similarity import subspace_johnson
|
| 2 |
+
from .similarity import vanilla_bert_score
|
| 3 |
+
from .similarity import subspace_bert_score
|
| 4 |
+
|
| 5 |
+
# Other metrics
|
| 6 |
+
from .fuzzy import *
|
| 7 |
+
from .symbolic import *
|
| 8 |
+
#from .optimal_transport import *
|
| 9 |
+
#from .grassmannian import *
|
subspace/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (322 Bytes). View file
|
|
|
subspace/__pycache__/fuzzy.cpython-310.pyc
ADDED
|
Binary file (1.15 kB). View file
|
|
|
subspace/__pycache__/similarity.cpython-310.pyc
ADDED
|
Binary file (5.61 kB). View file
|
|
|
subspace/__pycache__/symbolic.cpython-310.pyc
ADDED
|
Binary file (1.09 kB). View file
|
|
|
subspace/__pycache__/tool.cpython-310.pyc
ADDED
|
Binary file (2.84 kB). View file
|
|
|
subspace/fuzzy.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2018 Babylon Partners. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def fuzzify(s, u):
    """
    Sentence fuzzifier.

    Computes the membership vector for sentence S with respect to the
    universe U: the dot product of every word against every universe
    direction, max-pooled over the sentence's words and floored at zero.

    :param s: list of word embeddings for the sentence
    :param u: the universe matrix U with shape (K, d)
    :return: membership vector for the sentence
    """
    similarities = np.dot(s, u.T)       # (num_words, K) raw dot products
    pooled = similarities.max(axis=0)   # strongest match per universe direction
    return np.maximum(pooled, 0)        # clip negative memberships to zero
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def dynamax_jaccard(x, y):
    """
    DynaMax-Jaccard similarity measure between two sentences.

    The universe is the union of both sentences' word vectors; the score
    is the fuzzy Jaccard index of the two membership vectors.

    :param x: list of word embeddings for the first sentence
    :param y: list of word embeddings for the second sentence
    :return: similarity score between the two sentences
    """
    universe = np.vstack((x, y))
    membership_x = fuzzify(x, universe)
    membership_y = fuzzify(y, universe)

    fuzzy_intersection = np.sum(np.minimum(membership_x, membership_y))
    fuzzy_union = np.sum(np.maximum(membership_x, membership_y))
    return fuzzy_intersection / fuzzy_union
|
subspace/grassmannian.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import scipy
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def grassmann_distance(U, V):
    """Compute the squared geodesic distance on the Grassmann manifold.

    Args:
        U, V: matrices whose rows are bases of a linear subspace
    Return:
        sum of squared canonical angles between the two subspaces
    See Also:
        scipy.linalg.subspace_angles
    Example:
        >>> U = np.array([[1,0,0], [1,1,1]])
        >>> V = np.array([[0,1,0], [1,1,1]])
        >>> grassmann_distance(U, V)
    """
    # Canonical (principal) angles between the column spans of U.T and V.T.
    angles = scipy.linalg.subspace_angles(U.T, V.T)
    return (angles * angles).sum()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def grassmann_similarity(x, y):
    # Similarity as the negated Grassmann distance (larger = more similar).
    return -grassmann_distance(x, y)
|
subspace/legacy_operations/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .operations import *
|
subspace/legacy_operations/operations.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.linalg import orth
|
| 3 |
+
|
| 4 |
+
def subspace_np(A):
    """Compute orthonormal bases of the subspace spanned by the rows of A.

    Args:
        A: bases of the linear subspace (n_bases, dim)
    Return:
        Orthonormal bases, one per row
    Example:
        >>> A = np.random.random_sample((10, 300))
        >>> subspace_np(A)
    """
    # orth() works on column spans, so transpose in and back out.
    basis_columns = orth(A.T)
    return basis_columns.T
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def intersection_np(SA, SB, threshold=1e-2):
    """Compute bases of the intersection of two subspaces.

    Args:
        SA, SB: bases of the linear subspace (n_bases, dim)
        threshold: tolerance for treating a canonical angle as zero
    Return:
        Bases of intersection
    Example:
        >>> A = np.random.random_sample((10, 300))
        >>> B = np.random.random_sample((20, 300))
        >>> intersection_np(A, B)
    """
    assert threshold > 1e-6

    # Put the smaller basis first so the SVD below stays small.
    if SA.shape[0] > SB.shape[0]:
        return intersection_np(SB, SA, threshold)

    # Orthonormalize both bases.
    ortho_a = subspace_np(SA)
    ortho_b = subspace_np(SB)

    # Singular values of the cross-Gram matrix are cosines of canonical angles.
    left, singular, _ = np.linalg.svd(ortho_a @ ortho_b.T)

    # Keep directions whose canonical angle is numerically zero (cosine ~ 1).
    shared = left[:, np.abs(singular - 1.0) < threshold]
    return (ortho_a.T @ shared).T
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def sum_space_np(SA, SB):
    """Compute bases of the sum space of two subspaces.

    Args:
        SA, SB: bases of the linear subspace (n_bases, dim)
    Return:
        Bases of sum space
    Example:
        >>> A = np.random.random_sample((10, 300))
        >>> B = np.random.random_sample((20, 300))
        >>> sum_space_np(A, B)
    """
    # Stack both generating sets and orthonormalize the combined span.
    stacked = np.concatenate([SA, SB], axis=0)
    return subspace_np(stacked)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def orthogonal_complement_np(SA, threshold=1e-2):
    """Compute bases of the orthogonal complement of a subspace.

    Args:
        SA: bases of the linear subspace (n_bases, dim)
        threshold: singular values above this count toward the rank
    Return:
        Bases of the orthogonal complement
    Example:
        >>> A = np.random.random_sample((10, 300))
        >>> orthogonal_complement_np(A)
    """
    assert threshold > 1e-6
    left, singular, _ = np.linalg.svd(SA.T)
    # Numerical rank = number of significant singular values; the
    # remaining left-singular vectors span the complement.
    effective_rank = (singular > threshold).sum()
    return left[:, effective_rank:].T
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def soft_membership_np(A, v):
    """Compute the membership degree of the vector v for the subspace A.

    Args:
        A: bases of the linear subspace (n_bases, dim)
        v: vector (dim,)
    Return:
        soft membership degree
    Example:
        >>> A = np.array([[1,0,0], [0,1,0]])
        >>> v = np.array([1,0,0])
        >>> soft_membership_np(A, v)
        1.0
        >>> A = np.array([[1,0,0], [0,1,0]])
        >>> v = np.array([0,0,1])
        >>> soft_membership_np(A, v)
        0.0
    """
    # Treat v as a one-dimensional subspace and orthonormalize both.
    query = subspace_np(v.reshape(1, len(v)))
    basis = subspace_np(A)

    # Singular values of basis @ query.T are cosines of the canonical angles.
    _, cosines, _ = np.linalg.svd(basis @ query.T)
    cosines[cosines > 1] = 1  # guard against round-off just above 1

    # Maximum cosine of the canonical angles = the soft membership.
    return np.max(cosines)
|
subspace/operations.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def subspace(A: torch.Tensor) -> torch.Tensor:
    """
    Compute orthonormal bases of the subspace spanned by the rows of A.

    Args:
        A: bases of the linear subspace (n_bases, dim)
    Return:
        Orthonormal bases, one per row
    Example:
        >>> A = torch.rand(10, 300)
        >>> subspace(A)
    """
    # QR works on column spans, so transpose in and back out.
    q_factor, _ = torch.linalg.qr(A.t())
    return q_factor.t()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def intersection(SA: torch.Tensor, SB: torch.Tensor, threshold: float = 1e-2) -> torch.Tensor:
    """
    Compute bases of the intersection of two subspaces.

    Args:
        SA, SB: bases of the linear subspace (n_bases, dim)
        threshold: tolerance for treating a canonical angle as zero
    Return:
        Bases of intersection
    Example:
        >>> A = torch.rand(10, 300)
        >>> B = torch.rand(20, 300)
        >>> intersection(A, B)
    """
    assert threshold > 1e-6

    # Put the smaller basis first so the SVD below stays small.
    if SA.shape[0] > SB.shape[0]:
        return intersection(SB, SA, threshold)

    # Orthonormalize both bases.
    ortho_a = subspace(SA)
    ortho_b = subspace(SB)

    # Singular values of the cross-Gram matrix are cosines of canonical angles.
    left, singular, _ = torch.linalg.svd(ortho_a @ ortho_b.t())

    # Keep directions whose canonical angle is numerically zero (cosine ~ 1).
    shared = left[:, (singular - 1.0).abs() < threshold]
    return (ortho_a.t() @ shared).t()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def sum_space(SA: torch.Tensor, SB: torch.Tensor) -> torch.Tensor:
    """
    Compute bases of the sum space of two subspaces.

    Args:
        SA, SB: bases of the linear subspace (n_bases, dim)
    Return:
        Bases of sum space
    Example:
        >>> A = torch.rand(10, 300)
        >>> B = torch.rand(20, 300)
        >>> sum_space(A, B)
    """
    # Stack both generating sets and orthonormalize the combined span.
    return subspace(torch.cat([SA, SB], dim=0))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def orthogonal_complement(SA: torch.Tensor, threshold: float = 1e-2) -> torch.Tensor:
    """
    Compute bases of the orthogonal complement of a subspace.

    Args:
        SA: bases of the linear subspace (n_bases, dim)
        threshold: singular values above this count toward the rank
    Return:
        Bases of the orthogonal complement
    Example:
        >>> A = torch.rand(10, 300)
        >>> orthogonal_complement(A)
    """
    assert threshold > 1e-6
    left, singular, _ = torch.linalg.svd(SA.t())
    # Numerical rank = number of significant singular values; the
    # remaining left-singular vectors span the complement.
    effective_rank = (singular > threshold).sum()
    return left[:, effective_rank:].T
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def soft_membership(A: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Compute the membership degree of the vector v for the subspace A.

    Args:
        A: bases of the linear subspace (n_bases, dim)
        v: vector (dim,)
    Return:
        soft membership degree
    Example:
        >>> A = torch.tensor([[1,0,0], [0,1,0]])
        >>> v = torch.tensor([1,0,0])
        >>> soft_membership(A, v)
        1.0
        >>> A = torch.tensor([[1,0,0], [0,1,0]])
        >>> v = torch.tensor([0,0,1])
        >>> soft_membership(A, v)
        0.0
    """
    # Treat v as a one-dimensional subspace and orthonormalize both.
    query = subspace(v.reshape(1, len(v)))
    basis = subspace(A)

    # Singular values of basis @ query^T are cosines of the canonical angles.
    _, cosines, _ = torch.linalg.svd(basis @ query.t())
    cosines[cosines > 1] = 1  # guard against round-off just above 1

    # Maximum cosine of the canonical angles = the soft membership.
    return torch.max(cosines)
|
| 108 |
+
|
subspace/optimal_transport.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://kexue.fm/archives/7388
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy.optimize import linprog
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def wasserstein_distance(p, q, D):
    """Solve the discrete optimal-transport LP for marginals p, q and cost D.

    Builds one equality row per source marginal and one per target
    marginal; the last row is dropped because the marginal constraints
    are linearly dependent (both sum to the total mass).
    Returns the optimal transport cost found by scipy's linprog.
    """
    marginal_rows = []
    for source in range(len(p)):
        row = np.zeros_like(D)
        row[source, :] = 1
        marginal_rows.append(row.reshape(-1))
    for target in range(len(q)):
        row = np.zeros_like(D)
        row[:, target] = 1
        marginal_rows.append(row.reshape(-1))

    A_eq = np.array(marginal_rows)
    b_eq = np.concatenate([p, q])
    cost = D.reshape(-1)
    solution = linprog(cost, A_eq=A_eq[:-1], b_eq=b_eq[:-1])
    return solution.fun
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def word_rotator_distance(x, y):
    """Word Rotator's Distance: OT over norm-weighted, unit-normalized words.

    Mass is each word's L2 norm (normalized to sum to 1); the transport
    cost is one minus cosine similarity between unit word vectors.
    """
    norms_x = ((x ** 2).sum(axis=1, keepdims=True)) ** 0.5
    norms_y = ((y ** 2).sum(axis=1, keepdims=True)) ** 0.5
    p = norms_x[:, 0] / norms_x.sum()
    q = norms_y[:, 0] / norms_y.sum()
    cost = 1 - np.dot(x / norms_x, (y / norms_y).T)
    return wasserstein_distance(p, q, cost)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def word_mover_distance(x, y):
    """Word Mover's Distance: OT with uniform word mass and RMS-difference cost."""
    p = np.full(x.shape[0], 1.0 / x.shape[0])
    q = np.full(y.shape[0], 1.0 / y.shape[0])
    pairwise_diffs = x[:, None] - y[None, :]
    cost = np.sqrt(np.square(pairwise_diffs).mean(axis=2))
    return wasserstein_distance(p, q, cost)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def word_rotator_similarity(x, y):
    # Similarity as one minus the word rotator distance.
    return 1 - word_rotator_distance(x, y)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def word_mover_similarity(x, y):
    # Similarity as one minus the word mover distance.
    return 1 - word_mover_distance(x, y)
|
| 46 |
+
|
| 47 |
+
|
subspace/similarity.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def get_weights(A, B, weight):
    """Compute per-word weights for two batches of word embeddings.

    Args:
        A: word embeddings (batchsize, num_words_a, emb_dim)
        B: word embeddings (batchsize, num_words_b, emb_dim)
        weight: weighting scheme — "L2" (Euclidean norm of each embedding),
            "L1" (Manhattan norm), or "no" (uniform weight of 1 per word)
    Return:
        (weights_A, weights_B): tensors of shape (batchsize, num_words_a)
        and (batchsize, num_words_b)
    Raises:
        NotImplementedError: if ``weight`` names an unknown scheme.
    """
    if weight == "L2":
        weights_A = torch.linalg.norm(A, dim=2)
        weights_B = torch.linalg.norm(B, dim=2)
    elif weight == "L1":
        weights_A = torch.linalg.norm(A, dim=2, ord=1)
        weights_B = torch.linalg.norm(B, dim=2, ord=1)
    elif weight == "no":
        # Allocate uniform weights directly on the source device instead of
        # creating on CPU and copying with .to().
        weights_A = torch.ones(A.size(0), A.size(1), device=A.device)
        weights_B = torch.ones(B.size(0), B.size(1), device=B.device)
    else:
        raise NotImplementedError(f"unknown weight scheme: {weight!r}")
    return weights_A, weights_B
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def pairwise_cosine_matrix(matrix1, matrix2):
    """Batched pairwise cosine similarity between two sets of vectors.

    Args:
        matrix1: (batchsize, n1, dim)
        matrix2: (batchsize, n2, dim)
    Return:
        cosine similarities of every vector pair, shape (batchsize, n1, n2)
    """
    # Lengths of every row vector, kept as column vectors so their outer
    # product gives the pairwise norm products.
    lengths1 = torch.norm(matrix1, dim=-1, keepdim=True)
    lengths2 = torch.norm(matrix2, dim=-1, keepdim=True)
    inner = torch.matmul(matrix1, matrix2.transpose(1, 2))
    scale = torch.matmul(lengths1, lengths2.transpose(1, 2))
    return inner / scale
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def subspace_batch(A):
    """Return orthonormalized subspace bases for a batch of matrices.

    Arg:
        A: Bases of a linear subspace (batchsize, num_bases, emb_dim)
    Return:
        S: Orthonormalized bases of a linear subspace (batchsize, num_bases, emb_dim)
    Example:
        >>> A = torch.randn(5, 4, 300)
        >>> subspace_batch(A)
    """
    # QR works column-wise, so transpose to (emb_dim, num_bases), factor,
    # and transpose the orthonormal factor back to row-vector layout.
    columns = A.transpose(1, 2)
    Q, _ = torch.linalg.qr(columns)
    return Q.transpose(1, 2)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@torch.jit.script
def soft_membership_batch(S, v):
    """ Compute soft membership degree between a subspace and a vector for a batch of vectors

    TorchScript-compiled (``@torch.jit.script``), so the body must stay within
    the TorchScript subset of Python.

    Args:
        S: Orthonormalized bases of a linear subspace (batchsize, num_bases, emb_dim)
        v: vector (batchsize, emb_dim)
    Return:
        soft_membership degree (batchsize,) — the mean cosine of the canonical
        angles between each vector and its subspace
    Example:
        >>> S = torch.randn(5, 4, 300)
        >>> v = torch.randn(5, 300)
        >>> soft_membership_batch(S, v)
    """
    # normalize each vector, then reshape to a column (batchsize, emb_dim, 1)
    # so it can be right-multiplied by the bases S
    v = torch.nn.functional.normalize(v)
    v = v.view(v.size(0), v.size(1), 1)

    # compute SVD for cos(theta): the singular values of the projection of v
    # onto the bases are the canonical-angle cosines. .float() upcasts before
    # svdvals — presumably to guard against non-float32 inputs; TODO confirm.
    m = torch.matmul(S, v)
    s = torch.linalg.svdvals(m.float()) # s is the sequence of cos(theta_i)
    return torch.mean(s, 1)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def subspace_johnson(A, B, weight="L2"):
    """ Compute similarity between two vector sets (sentences)
    Args:
        A: Matrix of word embeddings for the first sentence
           (batchsize, num_bases, dim)
        B: Matrix of word embeddings for the second sentence
           (batchsize, num_bases, dim)
        weight: per-word weighting scheme, forwarded to ``get_weights``
    Return:
        similarity between A and B (batchsize,)
    Example:
        >>> A = torch.randn(5, 3, 300)
        >>> B = torch.randn(5, 4, 300)
        >>> subspace_johnson(A, B)
    """
    def _weighted_membership(embeddings, bases, weights):
        # For each word position, measure how strongly its vector belongs to
        # the other sentence's subspace, then take the weight-scaled sum.
        # `embeddings` is raw word vectors; `bases` is orthonormalized.
        per_word = [soft_membership_batch(bases, word)
                    for word in torch.transpose(embeddings, 0, 1)]
        membership = torch.transpose(torch.stack(per_word), 0, 1)
        return torch.sum(membership * weights, 1)

    weights_A, weights_B = get_weights(A, B, weight)

    # Symmetric combination: A measured against span(B) plus B measured
    # against span(A), each normalized by its total weight.
    left = _weighted_membership(A, subspace_batch(B), weights_A) / torch.sum(weights_A, 1)
    right = _weighted_membership(B, subspace_batch(A), weights_B) / torch.sum(weights_B, 1)
    return left + right
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def subspace_bert_score(A, B, weight="L2"):
    """ Compute similarity between two vector sets (sentences)
    Args:
        A: Matrix of word embeddings for the first sentence
           (batchsize, num_bases, dim)
        B: Matrix of word embeddings for the second sentence
           (batchsize, num_bases, dim)
        weight: per-word weighting scheme, forwarded to ``get_weights``
    Return:
        (P, R, F): precision, recall, and F1 scores, each (batchsize,)
    Example:
        >>> A = torch.randn(5, 3, 300)
        >>> B = torch.randn(5, 4, 300)
        >>> subspace_bert_score(A, B)
    """
    def _weighted_membership(embeddings, bases, weights):
        # Soft membership of every word vector in the other sentence's
        # subspace, summed with the per-word weights.
        per_word = [soft_membership_batch(bases, word)
                    for word in torch.transpose(embeddings, 0, 1)]
        membership = torch.transpose(torch.stack(per_word), 0, 1)
        return torch.sum(membership * weights, 1)

    weights_A, weights_B = get_weights(A, B, weight)

    # Compute P, R, F. R is the left term of SubspaceJohnson; P is the right.
    R = _weighted_membership(A, subspace_batch(B), weights_A) / torch.sum(weights_A, 1)
    P = _weighted_membership(B, subspace_batch(A), weights_B) / torch.sum(weights_B, 1)
    F = (2 * P * R) / (P + R)
    return P, R, F
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def vanilla_bert_score(A, B, weight="L2"):
    """ Compute similarity between two vector sets (sentences)
    Args:
        A: Matrix of word embeddings for the first sentence
           (batchsize, num_bases, dim)
        B: Matrix of word embeddings for the second sentence
           (batchsize, num_bases, dim)
        weight: per-word weighting scheme, forwarded to ``get_weights``
    Return:
        (P, R, F): precision, recall, and F1 scores, each (batchsize,)
    Example:
        >>> A = torch.randn(5, 3, 300)
        >>> B = torch.randn(5, 4, 300)
        >>> vanilla_bert_score(A, B)
    """
    def _weighted_best_match(cosines, dim, weights):
        # Greedy matching: each token takes its best cosine along `dim`,
        # then the maxima are combined with the per-token weights.
        best, _ = cosines.max(dim=dim)
        return (best * weights).sum(dim=1)

    weights_A, weights_B = get_weights(A, B, weight)

    cosines = pairwise_cosine_matrix(A, B)

    # Compute P, R, F. R is the left term of SubspaceJohnson; P is the right.
    R = _weighted_best_match(cosines, 2, weights_A) / weights_A.sum(dim=1)
    P = _weighted_best_match(cosines, 1, weights_B) / weights_B.sum(dim=1)
    F = (2 * P * R) / (P + R)
    return P, R, F
|
subspace/symbolic.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def symbolic_johnson(x, y):
    """
    Classical Johnson similarity measure between two sets
    :param x: list of words (strings) for the first sentence
    :param y: list of words (strings) for the second sentence
    :return: similarity score between two sentences (range [0, 2])
    """
    # An empty sentence has no overlap by definition.
    if not x or not y:
        return 0.0
    first, second = set(x), set(y)
    shared = len(first & second)
    # Overlap measured relative to each side, then summed.
    return shared / len(first) + shared / len(second)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def symbolic_jaccard(x, y):
    """
    Classical Jaccard similarity measure between two sets
    :param x: list of words (strings) for the first sentence
    :param y: list of words (strings) for the second sentence
    :return: similarity score between two sentences (range [0, 1])
    """
    # An empty sentence has no overlap by definition.
    if not x or not y:
        return 0.0
    first, second = set(x), set(y)
    # Intersection over union of the deduplicated word sets.
    return len(first & second) / len(first | second)
|
subspace/tool.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import transformers
|
| 3 |
+
from transformers import AutoTokenizer, AutoModel
|
| 4 |
+
from numpy import ndarray
|
| 5 |
+
import numpy as np
|
| 6 |
+
from .similarity import subspace_johnson, subspace_bert_score, vanilla_bert_score
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MySimilarity:
    """Base class for sentence-pair similarity scorers backed by a
    HuggingFace encoder.

    Subclasses override ``__call__`` to combine the encoded sentences with a
    specific similarity measure.
    """

    def __init__(self, device='cpu', model_name_or_path='bert-base-uncased'):
        # Set up tokenizer and encoder; the model is frozen in eval mode and
        # moved to the requested device.
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.model.eval()
        self.model.to(device)
        self.max_length = 128  # tokenizer truncation length

    def __call__(self, sentence1, sentence2, weight="L2"):
        # Hook for subclasses; the base class computes nothing.
        pass

    def encode(self, sentence, return_numpy=False, batch_size=12):
        """Encode sentence(s) into token-level hidden states.

        Args:
            sentence: a single string or a list of strings
            return_numpy: if True, return a numpy ndarray instead of a tensor
            batch_size: number of sentences encoded per forward pass
        Returns:
            Tensor (or ndarray) of shape (num_sentences, seq_len, hidden_dim),
            zero-padded along seq_len to the longest batch.
        Raises:
            ValueError: if an empty list of sentences is given.
        """
        if isinstance(sentence, str):
            sentence = [sentence]
        if not sentence:
            raise ValueError("encode() requires at least one sentence")

        embedding_list = []
        with torch.no_grad():
            for start in range(0, len(sentence), batch_size):
                inputs = self.tokenizer(
                    sentence[start:start + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt"
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs, return_dict=True)
                embedding_list.append(outputs.last_hidden_state.cpu())

        # BUG FIX: with padding=True each batch is padded only to its own
        # longest sentence, so different batches can have different sequence
        # lengths and torch.cat would fail. Zero-pad every batch to the
        # global maximum length before concatenating.
        max_len = max(e.size(1) for e in embedding_list)
        embedding_list = [
            torch.nn.functional.pad(e, (0, 0, 0, max_len - e.size(1)))
            for e in embedding_list
        ]
        embeddings = torch.cat(embedding_list, 0)

        if return_numpy and not isinstance(embeddings, ndarray):
            return embeddings.numpy()
        return embeddings
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SubspaceJohnsonSimilarity(MySimilarity):
    """Similarity scorer backed by the subspace Johnson measure."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        # Encode both sentences and score them with subspace_johnson.
        return subspace_johnson(self.encode(sentence1), self.encode(sentence2), weight)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class SubspaceBERTScore(MySimilarity):
    """Similarity scorer backed by the subspace BERTScore measure."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        # Encode both sentences and score them with subspace_bert_score.
        return subspace_bert_score(self.encode(sentence1), self.encode(sentence2), weight)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class VanillaBERTScore(MySimilarity):
    """Similarity scorer backed by the plain BERTScore measure."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        # Encode both sentences and score them with vanilla_bert_score.
        return vanilla_bert_score(self.encode(sentence1), self.encode(sentence2), weight)
|