Eylon Caplan committed on
Commit
ddb7b62
·
1 Parent(s): c82bc86

Deploy app code targeting HF Storage Bucket

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ bm25_indexes/
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import os
4
+ from core_logic import (
5
+ query_bm25_index,
6
+ lift_at_k,
7
+ lift_ci,
8
+ compute_keyword_similarity
9
+ )
10
+
11
+ # Path to the index directory where the Hugging Face Storage Bucket is mounted.
12
+ # Assuming the bucket is mounted at /data in the Space settings.
13
+ INDEX_DIR = '/data/bm25_indexes'
14
+
15
def get_available_indices():
    """List subdirectories of INDEX_DIR, one per available BM25 index.

    Returns a placeholder entry when the storage bucket is not mounted,
    so the dropdown is never empty.
    """
    if not os.path.exists(INDEX_DIR):
        return ["No indices found"]
    return [
        entry for entry in os.listdir(INDEX_DIR)
        if os.path.isdir(os.path.join(INDEX_DIR, entry))
    ]
19
+
20
def evaluate_keywords(index_name, target_demo, seed_words_str, generated_words_str):
    """Run the full evaluation for one submission.

    Parses the comma-separated keyword strings, computes BM25 lift metrics
    for the generated keywords against `target_demo`, computes BERT-subspace
    similarity between the seed and generated keyword sets, and returns
    (lift_markdown, similarity_markdown, top-10 hits DataFrame). Any
    exception is caught and surfaced as an error string in the first slot.
    """
    try:
        seed_words = [token.strip() for token in seed_words_str.split(",") if token.strip()]
        generated_words = [token.strip() for token in generated_words_str.split(",") if token.strip()]

        index_path = os.path.join(INDEX_DIR, index_name)

        # 1. BM25 lifts for the GENERATED words over a 1000-doc pool.
        df_results = query_bm25_index(index_path, generated_words, doc_count=1000)

        lift_lines = []
        for label, cutoff in (("Lift@100", 100), ("Lift@5%", 0.05)):
            lift_val = lift_at_k(df_results, target_demo, k=cutoff)
            pval, ci_lo, ci_hi = lift_ci(df_results, target_demo, k=cutoff)
            lift_lines.append(
                f"**{label}:** {lift_val:.3f} (p={pval:.4f}, 95% CI: [{ci_lo:.3f}, {ci_hi:.3f}])"
            )
        lift_text = "\n".join(lift_lines)

        # 2. BERT subspace similarity between seed and generated keywords.
        sim_metrics = compute_keyword_similarity(seed_words, generated_words, device='cpu')
        sim_text = "\n".join(
            f"**{metric}:** {sim_metrics[metric]:.4f}"
            for metric in ("Precision", "Recall", "F-Score")
        )

        # 3. Preview the top 10 retrieved hits.
        top_hits = df_results.head(10)[['id', 'score', 'demographic', 'content']]

        return lift_text, sim_text, top_hits

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error: {str(e)}", "", pd.DataFrame()
57
+
58
# Gradio Interface
# Layout: left column = inputs (index selector, target demographic, seed
# and generated keyword lists) plus a submit button; right column = the
# three metric outputs. `demo` is created at import time so a Spaces
# runtime can discover it; launch() only runs when executed as a script.
with gr.Blocks(title="BM25 Splits Demo") as demo:
    gr.Markdown("# 🚀 BM25 Target Demographic Evaluation Demo")
    gr.Markdown("Test retrieved demographic splits against predefined seed keywords and BERT Subspace metrics.")

    with gr.Row():
        with gr.Column():
            # Indices discovered under the mounted storage bucket.
            index_dropdown = gr.Dropdown(choices=get_available_indices(), label="Select BM25 Index")
            target_demo_input = gr.Textbox(label="Target Demographic (e.g., 'jewish', 'black')", value="jewish")

            seed_words_input = gr.Textbox(
                label="Target Demographic Seed Words (Comma separated)",
                value="the, be, to, of, and, a, in, that, have, I, it, for, not, on, with, he, as, you, do, at"
            )
            generated_words_input = gr.Textbox(
                label="Your Subspace/Generated Keywords (Comma separated)",
                value="church, jesus, christ, prayer"
            )

            submit_btn = gr.Button("Run Compute", variant="primary")

        with gr.Column():
            gr.Markdown("### 📊 Similarity Metrics (BERT-Score)")
            sim_output = gr.Markdown("Waiting to run...")

            gr.Markdown("### 📈 Lift Metrics (BM25)")
            lift_output = gr.Markdown("Waiting to run...")

            gr.Markdown("### 🔍 Top 10 Retrieved Hits")
            table_output = gr.Dataframe()

    # Wire the button: evaluate_keywords returns (lift_md, sim_md, df),
    # matching the three outputs in order.
    submit_btn.click(
        fn=evaluate_keywords,
        inputs=[index_dropdown, target_demo_input, seed_words_input, generated_words_input],
        outputs=[lift_output, sim_output, table_output]
    )

if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard Hugging Face Spaces binding.
    demo.launch(server_name="0.0.0.0", server_port=7860)
core_logic.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import hashlib
4
+ import random
5
+ import pandas as pd
6
+ from typing import Union, Tuple
7
+ import torch
8
+
9
+ from pyserini.search.lucene import LuceneSearcher
10
+ from pyserini.index.lucene import Document
11
+ from pyserini.analysis import get_lucene_analyzer
12
+ from pyserini.pyclass import autoclass
13
+ from scipy.stats import hypergeom
14
+ from subspace.tool import SubspaceBERTScore
15
+
16
+ # ==============================================================================
17
+ # BM25 Search and Query Building
18
+ # ==============================================================================
19
+
20
def get_standard_query(query: str, field: str = "contents", analyzer=None):
    """Parse `query` with Lucene's flexible StandardQueryParser.

    Args:
        query: raw Lucene query string (phrases, OR, etc.).
        field: default field the query terms apply to.
        analyzer: Lucene analyzer to use; defaults to the standard analyzer.
    Returns:
        A parsed Lucene Query object usable with LuceneSearcher.search.
    """
    JStandardQueryParser = autoclass(
        'org.apache.lucene.queryparser.flexible.standard.StandardQueryParser'
    )
    parser = JStandardQueryParser()
    parser.setAnalyzer(analyzer if analyzer is not None else get_lucene_analyzer())
    return parser.parse(query, field)
32
+
33
def _row_from_doc(jd: dict, score) -> dict:
    """Convert one parsed index document (a JSON dict) into a result row.

    Merges the optional JSON-encoded 'metadata' field into the row so
    fields stored there (e.g. 'demographic') become top-level columns.
    """
    row = {
        'id': jd.get("id"),
        'content': jd.get("contents", ""),
        'score': score
    }
    if "metadata" in jd and jd["metadata"]:
        # metadata is a JSON-encoded string nested inside the doc JSON
        row.update(json.loads(jd["metadata"]))
    return row


def query_bm25_index(index_path: str, keywords: list, doc_count: int = 1000) -> pd.DataFrame:
    """Load index, run a BM25 phrase search with the custom HuggingFace
    analyzer, and return a DataFrame of exactly `doc_count` rows.

    Retrieved hits come first (with BM25 scores). If fewer than
    `doc_count` documents match, the result is padded with a
    deterministic (query-seeded) random sample of unretrieved documents
    with score=None, so downstream lift metrics always see a fixed-size
    pool. (Fixes: removed an unused `returned_ids` set and consolidated
    the duplicated row-building code into _row_from_doc.)
    """
    # 1. Load searcher
    searcher = LuceneSearcher(index_path)

    # 2. Analyzer must match the tokenization used at indexing time.
    analyzer = get_lucene_analyzer(
        language='hgf_tokenizer',
        huggingFaceTokenizer='bert-base-uncased'
    )

    # 3. e.g. '"jesus" OR "christ"' -- quoted so each keyword is a phrase.
    query_string = " OR ".join(f'"{kw}"' for kw in keywords)

    # 4. Build the Lucene query via the flexible StandardQueryParser.
    phrase_q = get_standard_query(query_string, analyzer=analyzer)

    # 5. Search
    hits = searcher.search(phrase_q, doc_count)

    # 6. Parse results
    results = [
        _row_from_doc(json.loads(Document(hit.lucene_document).raw()), hit.score)
        for hit in hits
    ]
    returned_ext_ids = {r['id'] for r in results}

    # Pad with random unretrieved items so the pool is always doc_count.
    if len(results) < doc_count:
        needed = doc_count - len(results)
        total = searcher.num_docs

        # Internal docnums whose external ID wasn't already returned.
        pool = []
        for docnum in range(total):
            jd = json.loads(Document(searcher.doc(docnum)).raw())
            if jd.get("id") not in returned_ext_ids:
                pool.append(docnum)

        # Deterministically shuffle, seeded by the query string, so the
        # same query always produces the same padding sample.
        md5 = hashlib.md5(query_string.encode("utf-8")).hexdigest()
        rng = random.Random(int(md5, 16) % 2**32)
        rng.shuffle(pool)

        # Pull 'needed' more docs with no score attached.
        for docnum in pool[:needed]:
            jd = json.loads(Document(searcher.doc(docnum)).raw())
            results.append(_row_from_doc(jd, None))

    return pd.DataFrame(results)
117
+
118
+ # ==============================================================================
119
+ # Evaluation Metrics (Precision/Lift)
120
+ # ==============================================================================
121
+
122
+ def _resolve_k(df, k):
123
+ """Convert float percentages to absolute k or return k as an int."""
124
+ if isinstance(k, float) and 0.0 < k <= 1.0:
125
+ return int(len(df) * k)
126
+ return int(k)
127
+
128
def precision_at_k(df: pd.DataFrame, correct_demographic: str, k: Union[int, float]) -> float:
    """Fraction of the top-k rows whose 'demographic' equals the target.

    k may be an absolute count or a fraction of the frame (see _resolve_k).
    Assumes df rows are already ordered by retrieval score. Returns 0.0
    when the resolved cutoff is non-positive.
    """
    cutoff = _resolve_k(df, k)
    if cutoff <= 0:
        return 0.0
    matches = (df['demographic'].iloc[:cutoff] == correct_demographic).sum()
    return matches / float(cutoff)
135
+
136
def lift_at_k(df: pd.DataFrame, correct_demographic: str, k: Union[int, float]) -> float:
    """Lift@k: precision@k divided by the target's base rate in the pool.

    Returns 0.0 for an empty frame, a non-positive cutoff, or a zero
    base rate (no relevant rows at all).
    """
    cutoff = _resolve_k(df, k)
    total = len(df)
    if cutoff <= 0 or total == 0:
        return 0.0

    base_rate = (df['demographic'] == correct_demographic).sum() / float(total)
    if base_rate == 0:
        return 0.0

    return precision_at_k(df, correct_demographic, k) / base_rate
150
+
151
def hypergeometric_significance_test(df: pd.DataFrame, correct_demographic: str, k: Union[int, float], alpha: float = 0.05) -> Tuple[float, Tuple[int, int], Tuple[float, float]]:
    """Hypergeometric significance test for the top-n overlap.

    Under the null of a random ranking, the number of relevant documents
    in the top n follows Hypergeom(N, K, n). Returns
    (p_value, (L, U), (L/n, U/n)) where p_value = P[X >= observed] and
    [L, U] is the central 1 - alpha interval of the null distribution.
    Degenerate inputs (no relevant docs or n <= 0) yield zeros.
    """
    n = _resolve_k(df, k)
    N = len(df)

    relevant = (df['demographic'] == correct_demographic)
    K = int(relevant.sum())
    observed = int(relevant.iloc[:n].sum())

    if K == 0 or n <= 0:
        return 0.0, (0, 0), (0.0, 0.0)

    # Survival function at observed - 1 gives P[X >= observed].
    p_value = hypergeom.sf(observed - 1, N, K, n)
    lo = int(hypergeom.ppf(alpha / 2, N, K, n))
    hi = int(hypergeom.isf(alpha / 2, N, K, n))

    return p_value, (lo, hi), (lo / n, hi / n)
168
+
169
def lift_ci(df: pd.DataFrame, correct_demographic: str, k: Union[int, float], alpha: float = 0.05) -> Tuple[float, float, float]:
    """Confidence interval for lift@k via the hypergeometric null.

    Returns (p_value, lower_lift, upper_lift). Degenerate inputs (empty
    frame, no relevant docs, non-positive cutoff) return (0.0, 0.0, 0.0).

    Fix: the original computed K / len(df) before its degenerate-input
    guard, raising ZeroDivisionError on an empty DataFrame; the empty
    frame is now rejected before any division.
    """
    N = len(df)
    if N == 0:
        return 0.0, 0.0, 0.0

    n = _resolve_k(df, k)
    rel = (df['demographic'] == correct_demographic).astype(int)
    K = rel.sum()
    overall_proportion = K / float(N)

    if K == 0 or n <= 0 or overall_proportion == 0:
        return 0.0, 0.0, 0.0

    pval, (L, U), _ = hypergeometric_significance_test(df, correct_demographic, k, alpha)
    # Scale the null-interval precisions into lift units.
    lower_bound_lift = (L / n) / overall_proportion
    upper_bound_lift = (U / n) / overall_proportion

    return pval, lower_bound_lift, upper_bound_lift
186
+
187
+ # ==============================================================================
188
+ # Keyword Similarity (SubspaceBERTScore)
189
+ # ==============================================================================
190
+
191
def compute_keyword_similarity(set1: list, set2: list, device: str = None) -> dict:
    """BERT-subspace precision/recall/F-score between two keyword sets.

    Each keyword list is joined into one comma-separated "sentence" and
    scored with SubspaceBERTScore. Device defaults to CUDA when
    available, otherwise CPU. NOTE(review): the model is re-loaded on
    every call; cache the scorer if this becomes a hot path.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"Initializing BERT model on {device}...")
    scorer = SubspaceBERTScore(device=device, model_name_or_path='bert-base-uncased')

    precision, recall, fscore = scorer([", ".join(set1)], [", ".join(set2)])

    return {
        'Precision': precision.item(),
        'Recall': recall.item(),
        'F-Score': fscore.item()
    }
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ default-jre
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ pyserini==0.44.0
2
+ pandas==2.2.3
3
+ scipy==1.15.2
4
+ transformers==4.53.2
5
+ torch
6
+ gradio
subspace/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .similarity import subspace_johnson
2
+ from .similarity import vanilla_bert_score
3
+ from .similarity import subspace_bert_score
4
+
5
+ # Other metrics
6
+ from .fuzzy import *
7
+ from .symbolic import *
8
+ #from .optimal_transport import *
9
+ #from .grassmannian import *
subspace/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (322 Bytes). View file
 
subspace/__pycache__/fuzzy.cpython-310.pyc ADDED
Binary file (1.15 kB). View file
 
subspace/__pycache__/similarity.cpython-310.pyc ADDED
Binary file (5.61 kB). View file
 
subspace/__pycache__/symbolic.cpython-310.pyc ADDED
Binary file (1.09 kB). View file
 
subspace/__pycache__/tool.cpython-310.pyc ADDED
Binary file (2.84 kB). View file
 
subspace/fuzzy.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 Babylon Partners. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numpy as np
17
+
18
+
19
def fuzzify(s, u):
    """Sentence fuzzifier.

    Computes the membership vector of sentence S with respect to the
    universe U: for each universe word, the highest non-negative
    dot-product similarity achieved by any word of the sentence.

    :param s: word-embedding matrix for the sentence, shape (n, d)
    :param u: universe matrix U, shape (K, d)
    :return: membership vector of shape (K,)
    """
    similarities = s @ u.T
    membership = similarities.max(axis=0)
    # Clamp negatives to zero: membership degrees are non-negative.
    return np.maximum(membership, 0)
32
+
33
+
34
def dynamax_jaccard(x, y):
    """DynaMax-Jaccard similarity between two sentences.

    The universe is the union of both sentences' word vectors; each
    sentence is fuzzified against it and the fuzzy Jaccard index
    (sum of mins over sum of maxes) is returned.

    :param x: word-embedding matrix for the first sentence
    :param y: word-embedding matrix for the second sentence
    :return: similarity score between the two sentences
    """
    universe = np.vstack((x, y))
    mx = fuzzify(x, universe)
    my = fuzzify(y, universe)

    intersection = np.sum(np.minimum(mx, my))
    union = np.sum(np.maximum(mx, my))
    return intersection / union
subspace/grassmannian.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+
4
+
5
def grassmann_distance(U, V):
    """Squared geodesic (Grassmann) distance between two subspaces.

    Args:
        U, V: matrices whose rows are bases of each linear subspace
    Return:
        Sum of squared canonical angles between the subspaces
    See Also:
        scipy.linalg.subspace_angles
    Example:
        >>> U = np.array([[1,0,0], [1,1,1]])
        >>> V = np.array([[0,1,0], [1,1,1]])
        >>> grassmann_distance(U, V)
    """
    # Canonical (principal) angles; subspace_angles expects column bases.
    angles = scipy.linalg.subspace_angles(U.T, V.T)
    return (angles * angles).sum()
23
+
24
+
25
def grassmann_similarity(x, y):
    """Similarity as the negated Grassmann distance (larger = closer)."""
    return -grassmann_distance(x, y)
subspace/legacy_operations/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .operations import *
subspace/legacy_operations/operations.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.linalg import orth
3
+
4
def subspace_np(A):
    """ Compute orthonormal bases of the subspace
    Args:
        A: bases of the linear subspace (n_bases, dim)
    Return:
        Orthonormal bases spanning the same row space as A
    Example:
        >>> A = np.random.random_sample((10, 300))
        >>> subspace_np(A)
    """
    # scipy's orth works on column spaces, so transpose in and back out.
    return orth(A.T).T
15
+
16
+
17
def intersection_np(SA, SB, threshold=1e-2):
    """ Compute bases of the intersection of two subspaces.

    Canonical angles near zero (cosine near 1) mark shared directions;
    the corresponding left singular vectors, mapped back through SA,
    span the intersection.

    Args:
        SA, SB: bases of the linear subspace (n_bases, dim)
    Return:
        Bases of intersection
    Example:
        >>> A = np.random.random_sample((10, 300))
        >>> B = np.random.random_sample((20, 300))
        >>> intersection_np(A, B)
    """
    assert threshold > 1e-6

    # Recurse so the smaller subspace comes first.
    if SA.shape[0] > SB.shape[0]:
        return intersection_np(SB, SA, threshold)

    ortho_a = subspace_np(SA)
    ortho_b = subspace_np(SB)

    # Singular values of the overlap matrix are cos(canonical angles).
    left, cosines, _ = np.linalg.svd(ortho_a @ ortho_b.T)

    shared = left[:, np.abs(cosines - 1.0) < threshold]
    return (ortho_a.T @ shared).T
43
+
44
+
45
def sum_space_np(SA, SB):
    """ Compute bases of the sum space span(SA) + span(SB).

    Args:
        SA, SB: bases of the linear subspace (n_bases, dim)
    Return:
        Bases of sum space
    Example:
        >>> A = np.random.random_sample((10, 300))
        >>> B = np.random.random_sample((20, 300))
        >>> sum_space_np(A, B)
    """
    # Stack both generating sets and orthonormalize the union.
    return subspace_np(np.concatenate([SA, SB], axis=0))
58
+
59
+
60
def orthogonal_complement_np(SA, threshold=1e-2):
    """ Compute bases of the orthogonal complement of span(SA).

    Args:
        SA: bases of the linear subspace (n_bases, dim)
    Return:
        Bases of the orthogonal complement, shape (dim - rank, dim)
    Example:
        >>> A = np.random.random_sample((10, 300))
        >>> orthogonal_complement_np(A)
    """
    assert threshold > 1e-6
    left, singular_values, _ = np.linalg.svd(SA.T)
    # Columns of `left` beyond the numerical rank span the complement.
    rank = (singular_values > threshold).sum()
    return left[:, rank:].T
75
+
76
+
77
def soft_membership_np(A, v):
    """ Compute membership degree of the vector v for the subspace A.

    The singular values of the product of the orthonormalized bases are
    the cosines of the canonical angles; the largest (clamped to 1 to
    absorb numerical noise) is the membership degree.

    Args:
        A: bases of the linear subspace (n_bases, dim)
        v: vector (dim,)
    Return:
        soft membership degree
    Example:
        >>> A = np.array([[1,0,0], [0,1,0]])
        >>> v = np.array([1,0,0])
        >>> soft_membership_np(A, v)
        1.0
        >>> A = np.array([[1,0,0], [0,1,0]])
        >>> v = np.array([0,0,1])
        >>> soft_membership_np(A, v)
        0.0
    """
    unit_v = subspace_np(v.reshape(1, len(v)))
    basis = subspace_np(A)

    # The cosines of the canonical angles are the singular values.
    _, cosines, _ = np.linalg.svd(basis @ unit_v.T)
    return np.max(np.minimum(cosines, 1))
subspace/operations.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def subspace(A: torch.Tensor) -> torch.Tensor:
    """
    Compute orthonormal bases of the subspace
    Args:
        A: bases of the linear subspace (n_bases, dim)
    Return:
        Orthonormal bases spanning the same row space as A
    Example:
        >>> A = torch.rand(10, 300)
        >>> subspace(A)
    """
    # QR on the transpose orthonormalizes the row space.
    # NOTE(review): torch.linalg.qr requires a floating-point tensor, so
    # integer inputs (as in some sibling doctests) would raise — confirm
    # callers always pass floats.
    return torch.linalg.qr(A.t()).Q.t()
15
+
16
+
17
def intersection(SA: torch.Tensor, SB: torch.Tensor, threshold: float = 1e-2) -> torch.Tensor:
    """
    Compute bases of the intersection of two subspaces.

    Singular values of the orthonormalized overlap matrix are cosines of
    the canonical angles; values near 1 mark shared directions.

    Args:
        SA, SB: bases of the linear subspace (n_bases, dim)
    Return:
        Bases of intersection
    Example:
        >>> A = torch.rand(10, 300)
        >>> B = torch.rand(20, 300)
        >>> intersection(A, B)
    """
    assert threshold > 1e-6

    # Recurse so the smaller subspace comes first.
    if SA.shape[0] > SB.shape[0]:
        return intersection(SB, SA, threshold)

    ortho_a = subspace(SA)
    ortho_b = subspace(SB)

    left, cosines, _ = torch.linalg.svd(ortho_a @ ortho_b.t())

    shared = left[:, (cosines - 1.0).abs() < threshold]
    return (ortho_a.t() @ shared).t()
44
+
45
+
46
def sum_space(SA: torch.Tensor, SB: torch.Tensor) -> torch.Tensor:
    """
    Compute bases of the sum space span(SA) + span(SB).

    Args:
        SA, SB: bases of the linear subspace (n_bases, dim)
    Return:
        Bases of sum space
    Example:
        >>> A = torch.rand(10, 300)
        >>> B = torch.rand(20, 300)
        >>> sum_space(A, B)
    """
    # Stack both generating sets and orthonormalize the union.
    return subspace(torch.cat([SA, SB], dim=0))
60
+
61
+
62
def orthogonal_complement(SA: torch.Tensor, threshold: float = 1e-2) -> torch.Tensor:
    """
    Compute bases of the orthogonal complement of span(SA).

    Args:
        SA: bases of the linear subspace (n_bases, dim)
    Return:
        Bases of the orthogonal complement, shape (dim - rank, dim)
    Example:
        >>> A = torch.rand(10, 300)
        >>> orthogonal_complement(A)
    """
    assert threshold > 1e-6
    left, singular_values, _ = torch.linalg.svd(SA.t())
    # Columns of `left` beyond the numerical rank span the complement.
    rank = (singular_values > threshold).sum()
    return left[:, rank:].T
78
+
79
+
80
def soft_membership(A: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Compute membership degree of the vector v for the subspace A.

    The singular values of the orthonormalized product are cosines of
    the canonical angles; the largest (clamped to 1 to absorb numerical
    noise) is returned.

    NOTE(review): the doctest uses integer tensors, but torch.linalg.qr
    (inside `subspace`) requires floats — use float tensors in practice.

    Args:
        A: bases of the linear subspace (n_bases, dim)
        v: vector (dim,)
    Return:
        soft membership degree
    Example:
        >>> A = torch.tensor([[1,0,0], [0,1,0]])
        >>> v = torch.tensor([1,0,0])
        >>> soft_membership(A, v)
        1.0
        >>> A = torch.tensor([[1,0,0], [0,1,0]])
        >>> v = torch.tensor([0,0,1])
        >>> soft_membership(A, v)
        0.0
    """
    unit_v = subspace(v.reshape(1, len(v)))
    basis = subspace(A)

    # The cosines of the canonical angles are the singular values.
    _, cosines, _ = torch.linalg.svd(basis @ unit_v.t())
    return torch.max(torch.clamp(cosines, max=1.0))
108
+
subspace/optimal_transport.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://kexue.fm/archives/7388
2
+
3
+ import numpy as np
4
+ from scipy.optimize import linprog
5
+
6
+
7
def wasserstein_distance(p, q, D):
    """Solve the discrete optimal-transport LP and return its cost.

    Minimizes <T, D> over transport plans T whose row sums equal p and
    column sums equal q. One redundant equality constraint is dropped
    ([:-1]) since the constraints are linearly dependent.

    :param p: source distribution, shape (n,)
    :param q: target distribution, shape (m,)
    :param D: cost matrix, shape (n, m)
    :return: optimal transport cost
    """
    constraints = []
    # Row-sum constraints: mass leaving source i equals p[i].
    for i in range(len(p)):
        row = np.zeros_like(D)
        row[i, :] = 1
        constraints.append(row.reshape(-1))
    # Column-sum constraints: mass entering target j equals q[j].
    for j in range(len(q)):
        col = np.zeros_like(D)
        col[:, j] = 1
        constraints.append(col.reshape(-1))

    A_eq = np.array(constraints)
    b_eq = np.concatenate([p, q])
    result = linprog(D.reshape(-1), A_eq=A_eq[:-1], b_eq=b_eq[:-1])
    return result.fun
22
+
23
+
24
def word_rotator_distance(x, y):
    """Word Rotator's Distance.

    Optimal transport between the two sentences where each word's mass
    is proportional to its embedding norm and the cost is cosine
    distance between normalized embeddings.
    """
    x_norm = (x ** 2).sum(axis=1, keepdims=True) ** 0.5
    y_norm = (y ** 2).sum(axis=1, keepdims=True) ** 0.5

    mass_x = x_norm[:, 0] / x_norm.sum()
    mass_y = y_norm[:, 0] / y_norm.sum()

    cost = 1 - (x / x_norm) @ (y / y_norm).T
    return wasserstein_distance(mass_x, mass_y, cost)
31
+
32
+
33
def word_mover_distance(x, y):
    """Word Mover's Distance.

    Optimal transport with uniform mass on every word and a
    root-mean-square per-dimension cost between word pairs.
    """
    mass_x = np.ones(x.shape[0]) / x.shape[0]
    mass_y = np.ones(y.shape[0]) / y.shape[0]

    cost = np.sqrt(np.square(x[:, None] - y[None, :]).mean(axis=2))
    return wasserstein_distance(mass_x, mass_y, cost)
38
+
39
+
40
def word_rotator_similarity(x, y):
    """Similarity form of Word Rotator's Distance (1 - distance)."""
    return 1 - word_rotator_distance(x, y)
42
+
43
+
44
def word_mover_similarity(x, y):
    """Similarity form of Word Mover's Distance (1 - distance)."""
    return 1 - word_mover_distance(x, y)
46
+
47
+
subspace/similarity.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def get_weights(A, B, weight):
    """Per-token weights for two batches of token embeddings.

    weight: "L2" (vector norms), "L1" (absolute-sum norms) or "no"
    (uniform ones). Output shapes are (batch, num_tokens).
    Raises NotImplementedError for any other scheme.
    """
    if weight == "L2":
        return torch.linalg.norm(A, dim=2), torch.linalg.norm(B, dim=2)
    if weight == "L1":
        return (torch.linalg.norm(A, dim=2, ord=1),
                torch.linalg.norm(B, dim=2, ord=1))
    if weight == "no":
        return (torch.ones(A.size(0), A.size(1)).to(A.device),
                torch.ones(B.size(0), B.size(1)).to(B.device))
    raise NotImplementedError
17
+
18
+
19
def pairwise_cosine_matrix(matrix1, matrix2):
    """Batched pairwise cosine similarity.

    matrix1: (batch, n, d), matrix2: (batch, m, d) -> (batch, n, m).
    NOTE(review): rows with zero norm produce NaN (0/0) — callers must
    avoid all-zero embeddings.
    """
    numerator = torch.matmul(matrix1, matrix2.transpose(1, 2))
    norms1 = torch.norm(matrix1, dim=-1, keepdim=True)
    norms2 = torch.norm(matrix2, dim=-1, keepdim=True)
    denominator = torch.matmul(norms1, norms2.transpose(1, 2))
    return numerator / denominator
25
+
26
+
27
def subspace_batch(A):
    """ Orthonormalize each matrix of a batch via QR decomposition.

    Arg:
        A: Bases of a linear subspace (batchsize, num_bases, emb_dim)
    Return:
        S: Orthonormalized bases (batchsize, num_bases, emb_dim)
    Example:
        >>> A = torch.randn(5, 4, 300)
        >>> subspace_batch(A)
    """
    # QR on the transposed batch orthonormalizes each row space.
    Q, _ = torch.linalg.qr(torch.transpose(A, 1, 2))
    return torch.transpose(Q, 1, 2)
40
+
41
+
42
@torch.jit.script
def soft_membership_batch(S, v):
    """ Compute soft membership degree between a subspace and a vector for a batch of vectors

    The singular values of S @ v_normalized are the cosines of the
    canonical angles between the subspace and the vector; their mean
    over the singular-value dimension is returned.

    Args:
        S: Orthonormalized bases of a linear subspace (batchsize, num_bases, emb_dim)
        v: vector (batchsize, emb_dim)
    Return:
        soft_membership degree (batchsize,)
    Example:
        >>> S = torch.randn(5, 4, 300)
        >>> v = torch.randn(5, 300)
        >>> soft_membership_batch(S, v)
    """
    # normalize each vector to unit length so singular values are cosines
    v = torch.nn.functional.normalize(v)
    v = v.view(v.size(0), v.size(1), 1)

    # compute SVD for cos(theta)
    m = torch.matmul(S, v)
    s = torch.linalg.svdvals(m.float()) # s is the sequence of cos(theta_i)
    return torch.mean(s, 1)
64
+
65
+
66
def subspace_johnson(A, B, weight="L2"):
    """ Symmetrized subspace similarity between two batches of sentences.

    Each side's token vectors are scored by soft membership in the other
    side's orthonormalized subspace, combined with per-token weights,
    and the two normalized directed scores are summed.

    Args:
        A: Matrix of word embeddings for the first sentence
           (batchsize, num_bases, dim)
        B: Matrix of word embeddings for the second sentence
           (batchsize, num_bases, dim)
    Return:
        similarity between A and B (batchsize,)
    Example:
        >>> A = torch.randn(5, 3, 300)
        >>> B = torch.randn(5, 4, 300)
        >>> subspace_johnson(A, B)
    """
    def directed_score(tokens, basis, weights):
        # tokens: word embeddings; basis: orthonormalized bases
        memberships = torch.stack([
            soft_membership_batch(basis, vec)
            for vec in torch.transpose(tokens, 0, 1)
        ])
        memberships = torch.transpose(memberships, 0, 1)
        return torch.sum(memberships * weights, 1)

    weights_A, weights_B = get_weights(A, B, weight)

    forward = directed_score(A, subspace_batch(B), weights_A) / torch.sum(weights_A, 1)
    backward = directed_score(B, subspace_batch(A), weights_B) / torch.sum(weights_B, 1)
    return forward + backward
97
+
98
+
99
+
100
def subspace_bert_score(A, B, weight="L2"):
    """ BERTScore-style precision/recall/F using subspace soft membership.

    Args:
        A: Matrix of word embeddings for the first sentence
           (batchsize, num_bases, dim)
        B: Matrix of word embeddings for the second sentence
           (batchsize, num_bases, dim)
    Return:
        (P, R, F) tensors, each of shape (batchsize,)
    Example:
        >>> A = torch.randn(5, 3, 300)
        >>> B = torch.randn(5, 4, 300)
        >>> subspace_bert_score(A, B)
    """
    def directed_score(tokens, basis, weights):
        # tokens: word embeddings; basis: orthonormalized bases
        memberships = torch.stack([
            soft_membership_batch(basis, vec)
            for vec in torch.transpose(tokens, 0, 1)
        ])
        memberships = torch.transpose(memberships, 0, 1)
        return torch.sum(memberships * weights, 1)

    weights_A, weights_B = get_weights(A, B, weight)

    # Recall: A's tokens against B's subspace (left term of
    # SubspaceJohnson); Precision: the reverse (right term).
    R = directed_score(A, subspace_batch(B), weights_A) / torch.sum(weights_A, 1)
    P = directed_score(B, subspace_batch(A), weights_B) / torch.sum(weights_B, 1)
    F = (2 * P * R) / (P + R)
    return P, R, F
132
+
133
+
134
def vanilla_bert_score(A, B, weight="L2"):
    """ Standard greedy-matching BERTScore precision/recall/F.

    Args:
        A: Matrix of word embeddings for the first sentence
           (batchsize, num_bases, dim)
        B: Matrix of word embeddings for the second sentence
           (batchsize, num_bases, dim)
    Return:
        (P, R, F) tensors, each of shape (batchsize,)
    Example:
        >>> A = torch.randn(5, 3, 300)
        >>> B = torch.randn(5, 4, 300)
        >>> vanilla_bert_score(A, B)
    """
    def weighted_best_match(pairwise_cos, dim, weights):
        best, _ = pairwise_cos.max(dim=dim)
        return torch.sum(best * weights, 1)

    weights_A, weights_B = get_weights(A, B, weight)

    pairwise_cos = pairwise_cosine_matrix(A, B)

    # R: each token of A matched to its best token in B (left term of
    # SubspaceJohnson); P: each token of B matched to its best token in
    # A (right term).
    R = weighted_best_match(pairwise_cos, 2, weights_A) / torch.sum(weights_A, 1)
    P = weighted_best_match(pairwise_cos, 1, weights_B) / torch.sum(weights_B, 1)
    F = (2 * P * R) / (P + R)
    return P, R, F
subspace/symbolic.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def symbolic_johnson(x, y):
    """
    Classical Johnson similarity measure between two sets.

    Computed on the deduplicated sets as |x∩y|/|x| + |x∩y|/|y|.
    Empty input on either side yields 0.0.

    :param x: list of words (strings) for the first sentence
    :param y: list of words (strings) for the second sentence
    :return: similarity score between two sentences
    """
    if not x or not y:
        return 0.0
    xs, ys = set(x), set(y)
    overlap = len(xs & ys)
    return overlap / len(xs) + overlap / len(ys)
14
+
15
+
16
def symbolic_jaccard(x, y):
    """
    Classical Jaccard similarity measure between two sets.

    Computed on the deduplicated sets as |x∩y| / |x∪y|.
    Empty input on either side yields 0.0.

    :param x: list of words (strings) for the first sentence
    :param y: list of words (strings) for the second sentence
    :return: similarity score between two sentences
    """
    if not x or not y:
        return 0.0
    xs, ys = set(x), set(y)
    return len(xs & ys) / len(xs | ys)
subspace/tool.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ from transformers import AutoTokenizer, AutoModel
4
+ from numpy import ndarray
5
+ import numpy as np
6
+ from .similarity import subspace_johnson, subspace_bert_score, vanilla_bert_score
7
+
8
+
9
class MySimilarity:
    """Base class wrapping a HuggingFace encoder.

    Subclasses implement __call__ to score a pair of sentence batches;
    the base class provides model loading and the `encode` helper.
    (Fix: removed the dead `single_sentence` flag that was computed but
    never read.)
    """

    def __init__(self, device='cpu', model_name_or_path='bert-base-uncased'):
        # Load tokenizer + encoder once; eval mode, pinned to `device`.
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.model.eval()
        self.model.to(device)
        # Truncation limit applied by the tokenizer in encode().
        self.max_length = 128

    def __call__(self, sentence1, sentence2, weight="L2"):
        # Subclasses implement the actual similarity; base is a no-op.
        pass

    def encode(self, sentence, return_numpy=False, batch_size=12):
        """Encode sentence(s) to token embeddings (last hidden state).

        Accepts a single string or a list of strings. Returns a tensor
        of shape (n_sentences, seq_len, hidden) on CPU, or an ndarray
        when return_numpy is True.

        NOTE(review): batches are padded independently, so torch.cat
        assumes every batch pads to the same sequence length — confirm
        inputs, or encode in a single batch.
        """
        if isinstance(sentence, str):
            sentence = [sentence]

        embedding_list = []
        with torch.no_grad():
            # ceil(len / batch_size) batches
            total_batch = -(-len(sentence) // batch_size)
            for batch_id in range(total_batch):
                batch = sentence[batch_id * batch_size:(batch_id + 1) * batch_size]
                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt"
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs, return_dict=True)

                # Move off-device immediately to free accelerator memory.
                embedding_list.append(outputs.last_hidden_state.cpu())

        embeddings = torch.cat(embedding_list, 0)

        if return_numpy and not isinstance(embeddings, ndarray):
            return embeddings.numpy()
        return embeddings
52
+
53
+
54
class SubspaceJohnsonSimilarity(MySimilarity):
    """Scores sentence pairs with the symmetrized subspace Johnson metric."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        embedded1 = self.encode(sentence1)
        embedded2 = self.encode(sentence2)
        return subspace_johnson(embedded1, embedded2, weight)
59
+
60
+
61
class SubspaceBERTScore(MySimilarity):
    """Scores sentence pairs with subspace-based BERTScore (P, R, F)."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        embedded1 = self.encode(sentence1)
        embedded2 = self.encode(sentence2)
        return subspace_bert_score(embedded1, embedded2, weight)
66
+
67
+
68
class VanillaBERTScore(MySimilarity):
    """Scores sentence pairs with standard greedy-matching BERTScore."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        embedded1 = self.encode(sentence1)
        embedded2 = self.encode(sentence2)
        return vanilla_bert_score(embedded1, embedded2, weight)