AmrYassinIsFree committed
Commit 1587b68 · Parent: bf74331

streamlit app and publishing to HF
Files changed (4)
  1. .streamlit/config.toml +9 -0
  2. README.md +12 -0
  3. app.py +288 -0
  4. requirements.txt +1 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,9 @@
+ [theme]
+ primaryColor = "#4C72B0"
+ backgroundColor = "#0E1117"
+ secondaryBackgroundColor = "#1A1D23"
+ textColor = "#FAFAFA"
+
+ [server]
+ maxUploadSize = 50
+ enableXsrfProtection = true
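
The [theme] block restyles the app's colors, and [server] caps uploads at 50 MB with XSRF protection left on. Streamlit loads this file automatically from .streamlit/config.toml. As a quick validity check, here is a sketch using the stdlib tomllib parser (Python 3.11+):

import tomllib  # stdlib TOML parser, Python 3.11+

with open(".streamlit/config.toml", "rb") as f:  # tomllib requires binary mode
    cfg = tomllib.load(f)

assert cfg["theme"]["primaryColor"] == "#4C72B0"
assert cfg["server"]["maxUploadSize"] == 50  # upload cap, in megabytes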
README.md CHANGED
@@ -1,3 +1,15 @@
+ ---
+ title: Embedding Bench
+ emoji: 📏
+ colorFrom: blue
+ colorTo: purple
+ sdk: streamlit
+ sdk_version: "1.56.0"
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
  # embedding-bench
  
  Compare text embedding models across retrieval quality, inference speed, and memory footprint. Everything runs locally — no external API calls.
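
The added YAML front matter is the metadata block that Hugging Face Spaces reads from the top of README.md to configure the Space: the SDK and version to provision, the entry-point file, and the card's appearance. A sketch of reading it back programmatically, assuming PyYAML is available:

import yaml  # assumes PyYAML is installed

with open("README.md", encoding="utf-8") as f:
    text = f.read()

# The metadata sits between the first two '---' markers.
_, front_matter, _ = text.split("---", 2)
meta = yaml.safe_load(front_matter)
print(meta["sdk"], meta["sdk_version"], meta["app_file"])  # streamlit 1.56.0 app.py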
app.py ADDED
@@ -0,0 +1,288 @@
+ from __future__ import annotations
+
+ import io
+ import csv
+ import time
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import streamlit as st
+
+ from corpus import build_corpus
+ from dataset_config import DATASET_PRESETS, DatasetConfig
+ from evals.quality import evaluate_quality
+ from evals.speed import evaluate_speed
+ from models import REGISTRY, ModelConfig
+ from wrapper import load_model
+
+ # ---------------------------------------------------------------------------
+ # Page config
+ # ---------------------------------------------------------------------------
+ st.set_page_config(
+     page_title="Embedding Bench",
+     page_icon="📏",
+     layout="wide",
+ )
+
+ st.title("📏 Embedding Bench")
+ st.caption("Compare text embedding models on quality, speed & memory — all in your browser.")
+
+ # ---------------------------------------------------------------------------
+ # Sidebar — configuration
+ # ---------------------------------------------------------------------------
+ st.sidebar.header("Models")
+ available_models = list(REGISTRY.keys())
+ selected_models = st.sidebar.multiselect(
+     "Select models",
+     available_models,
+     default=["mpnet", "bge-small"] if len(available_models) >= 2 else available_models[:1],
+ )
+
+ st.sidebar.header("Datasets")
+ available_datasets = list(DATASET_PRESETS.keys())
+ selected_datasets = st.sidebar.multiselect(
+     "Select dataset presets",
+     available_datasets,
+     default=["sts"],
+ )
+
+ max_pairs = st.sidebar.number_input(
+     "Max pairs per dataset",
+     min_value=100,
+     max_value=50000,
+     value=1000,
+     step=100,
+     help="Limits the number of pairs evaluated. Keep low for large datasets.",
+ )
+
+ st.sidebar.header("Speed & Memory")
+ run_speed = st.sidebar.checkbox("Run speed benchmark", value=False)
+ run_memory = st.sidebar.checkbox("Run memory benchmark", value=False)
+
+ corpus_size = 500
+ num_runs = 3
+ batch_size = 64
+ if run_speed or run_memory:
+     corpus_size = st.sidebar.number_input("Corpus size", 100, 10000, 500, step=100)
+     batch_size = st.sidebar.number_input("Batch size", 8, 512, 64, step=8)
+ if run_speed:
+     num_runs = st.sidebar.number_input("Speed runs", 1, 10, 3)
+
+ # ---------------------------------------------------------------------------
+ # Helpers
+ # ---------------------------------------------------------------------------
+
+ @st.cache_resource(show_spinner="Loading model...")
+ def get_model(model_key: str):
+     cfg = REGISTRY[model_key]
+     return load_model(cfg)
+
+
+ def flatten_result(r: dict) -> dict:
+     flat = {"Model": r["name"]}
+     for ds_key, metrics in r.get("quality", {}).items():
+         for metric_name, value in metrics.items():
+             flat[f"{ds_key}/{metric_name}"] = value
+     speed = r.get("speed")
+     if speed:
+         flat["Speed (sent/s)"] = speed["sentences_per_second"]
+         flat["Median Time (s)"] = speed["median_seconds"]
+     mem = r.get("memory_mb")
+     if mem is not None:
+         flat["Memory (MB)"] = mem
+     return flat
+
+
+ def results_to_csv(results: list[dict]) -> str:
+     rows = [flatten_result(r) for r in results]
+     fieldnames = list(rows[0].keys())
+     for row in rows[1:]:
+         for k in row:
+             if k not in fieldnames:
+                 fieldnames.append(k)
+     buf = io.StringIO()
+     writer = csv.DictWriter(buf, fieldnames=fieldnames)
+     writer.writeheader()
+     writer.writerows(rows)
+     return buf.getvalue()
+
+
+ # ---------------------------------------------------------------------------
+ # Run benchmark
+ # ---------------------------------------------------------------------------
+ if not selected_models:
+     st.warning("Select at least one model from the sidebar.")
+     st.stop()
+
+ if not selected_datasets:
+     st.warning("Select at least one dataset from the sidebar.")
+     st.stop()
+
+ run_btn = st.sidebar.button("🚀 Run Benchmark", type="primary", use_container_width=True)
+
+ if run_btn:
+     ds_configs = [DATASET_PRESETS[k] for k in selected_datasets]
+     results = []
+     progress = st.progress(0, text="Starting...")
+     total_steps = len(selected_models) * (len(ds_configs) + int(run_speed) + int(run_memory))
+     step = 0
+
+     for model_key in selected_models:
+         cfg = REGISTRY[model_key]
+         result: dict = {"name": cfg.name, "is_baseline": cfg.is_baseline}
+
+         # Quality
+         model = get_model(model_key)
+         quality_results = {}
+         for ds_cfg in ds_configs:
+             ds_key = ds_cfg.name.split("/")[-1]
+             step += 1
+             progress.progress(
+                 step / total_steps,
+                 text=f"Evaluating {cfg.name} on {ds_key}...",
+             )
+             quality_results[ds_key] = evaluate_quality(model, ds_cfg, max_pairs=max_pairs)
+         result["quality"] = quality_results
+
+         # Speed
+         if run_speed:
+             step += 1
+             progress.progress(step / total_steps, text=f"Speed benchmark: {cfg.name}...")
+             corpus = build_corpus(corpus_size, ds_configs[0])
+             result["speed"] = evaluate_speed(model, corpus, num_runs=num_runs, batch_size=batch_size)
+
+         # Memory
+         if run_memory:
+             step += 1
+             progress.progress(step / total_steps, text=f"Memory benchmark: {cfg.name}...")
+             from evals.memory import evaluate_memory
+             corpus = build_corpus(corpus_size, ds_configs[0])
+             result["memory_mb"] = evaluate_memory(
+                 cfg.model_id, corpus, batch_size=batch_size, backend=cfg.backend,
+             )
+
+         results.append(result)
+
+     progress.progress(1.0, text="Done!")
+     time.sleep(0.3)
+     progress.empty()
+
+     # Store results in session state
+     st.session_state["results"] = results
+     st.session_state["selected_datasets"] = selected_datasets
+
+ # ---------------------------------------------------------------------------
+ # Display results
+ # ---------------------------------------------------------------------------
+ if "results" not in st.session_state:
+     st.info("Configure options in the sidebar and hit **Run Benchmark**.")
+     st.stop()
+
+ results = st.session_state["results"]
+ selected_datasets = st.session_state["selected_datasets"]
+
+ # --- Results table ---
+ st.header("Results")
+ flat_rows = [flatten_result(r) for r in results]
+ st.dataframe(flat_rows, use_container_width=True)
+
+ # --- CSV download ---
+ csv_data = results_to_csv(results)
+ st.download_button(
+     "📥 Download CSV",
+     data=csv_data,
+     file_name="embedding_bench_results.csv",
+     mime="text/csv",
+ )
+
+ # --- Charts ---
+ st.header("Charts")
+ models = [r["name"] for r in results]
+
+ # Discover datasets
+ ds_keys: list[str] = []
+ for r in results:
+     q = r.get("quality")
+     if q:
+         ds_keys = list(q.keys())
+         break
+
+ for ds_key in ds_keys:
+     first_metrics = None
+     for r in results:
+         m = r.get("quality", {}).get(ds_key)
+         if m:
+             first_metrics = m
+             break
+     if not first_metrics:
+         continue
+
+     if "spearman" in first_metrics:
+         values = [r.get("quality", {}).get(ds_key, {}).get("spearman", 0) for r in results]
+         fig, ax = plt.subplots(figsize=(max(6, len(models) * 1.5), 4))
+         bars = ax.bar(models, values, color="#4C72B0")
+         ax.set_ylabel("Spearman Correlation")
+         ax.set_title(f"Quality — {ds_key}")
+         ax.set_ylim(0, 1)
+         for bar, v in zip(bars, values):
+             ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
+                     f"{v:.4f}", ha="center", va="bottom", fontsize=9)
+         plt.xticks(rotation=30, ha="right")
+         plt.tight_layout()
+         st.pyplot(fig)
+         plt.close(fig)
+     else:
+         metric_names = ["mrr", "recall@1", "recall@5", "recall@10"]
+         x = np.arange(len(models))
+         width = 0.18
+         colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"]
+
+         fig, ax = plt.subplots(figsize=(max(8, len(models) * 2.2), 4.5))
+         for i, (metric, color) in enumerate(zip(metric_names, colors)):
+             values = [r.get("quality", {}).get(ds_key, {}).get(metric, 0) for r in results]
+             offset = (i - 1.5) * width
+             bars = ax.bar(x + offset, values, width, label=metric, color=color)
+             for bar, v in zip(bars, values):
+                 ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
+                         f"{v:.2f}", ha="center", va="bottom", fontsize=7)
+         ax.set_ylabel("Score")
+         ax.set_title(f"Retrieval Quality — {ds_key}")
+         ax.set_ylim(0, 1.15)
+         ax.set_xticks(x)
+         ax.set_xticklabels(models, rotation=30, ha="right")
+         ax.legend()
+         plt.tight_layout()
+         st.pyplot(fig)
+         plt.close(fig)
+
+ # Speed chart
+ speed_values = [r.get("speed", {}).get("sentences_per_second", 0) for r in results]
+ if any(v > 0 for v in speed_values):
+     fig, ax = plt.subplots(figsize=(max(6, len(models) * 1.5), 4))
+     bars = ax.bar(models, speed_values, color="#55A868")
+     ax.set_ylabel("Sentences / second")
+     ax.set_title("Encoding Speed")
+     for bar, v in zip(bars, speed_values):
+         if v > 0:
+             ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
+                     str(v), ha="center", va="bottom", fontsize=9)
+     plt.xticks(rotation=30, ha="right")
+     plt.tight_layout()
+     st.pyplot(fig)
+     plt.close(fig)
+
+ # Memory chart
+ mem_values = [r.get("memory_mb", 0) for r in results]
+ if any(v > 0 for v in mem_values):
+     fig, ax = plt.subplots(figsize=(max(6, len(models) * 1.5), 4))
+     bars = ax.bar(models, mem_values, color="#C44E52")
+     ax.set_ylabel("Peak Memory (MB)")
+     ax.set_title("Memory Usage")
+     for bar, v in zip(bars, mem_values):
+         if v > 0:
+             ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
+                     str(v), ha="center", va="bottom", fontsize=9)
+     plt.xticks(rotation=30, ha="right")
+     plt.tight_layout()
+     st.pyplot(fig)
+     plt.close(fig)
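
app.py imports several local modules that are not part of this commit (corpus, dataset_config, evals, models, wrapper). Judging only from the attributes the app touches (cfg.name, cfg.model_id, cfg.backend, cfg.is_baseline), the contract it expects from models.py looks roughly like the sketch below. Everything beyond the two imported names REGISTRY and ModelConfig is an assumption, including the hub IDs.

from dataclasses import dataclass

@dataclass(frozen=True)
class ModelConfig:
    name: str                 # display name shown in tables and chart axes
    model_id: str             # hub id handed to load_model / evaluate_memory
    backend: str              # encoding backend, forwarded to evaluate_memory
    is_baseline: bool = False

# Keys match the sidebar defaults in app.py ("mpnet", "bge-small").
REGISTRY: dict[str, ModelConfig] = {
    "mpnet": ModelConfig(
        name="all-mpnet-base-v2",
        model_id="sentence-transformers/all-mpnet-base-v2",  # assumed hub id
        backend="torch",
        is_baseline=True,
    ),
    "bge-small": ModelConfig(
        name="bge-small-en-v1.5",
        model_id="BAAI/bge-small-en-v1.5",  # assumed hub id
        backend="torch",
    ),
}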
requirements.txt CHANGED
@@ -8,3 +8,4 @@ libembedding
  numpy
  scipy
  matplotlib
+ streamlit