AmrYassinIsFree committed
Commit bf74331 · 1 Parent(s): 7052097

new passage-query dataset style

README.md CHANGED
@@ -27,6 +27,8 @@ pip install -r requirements.txt
 
 ## Usage
 
+### Basic
+
 ```bash
 # Full benchmark (quality + speed + memory)
 python bench.py
@@ -45,14 +47,89 @@ python bench.py --skip-memory
 python bench.py --corpus-size 500 --batch-size 32 --num-runs 5
 ```
 
+### Datasets
+
+By default, quality is evaluated on the STS Benchmark. You can evaluate on multiple HuggingFace datasets using built-in presets:
+
+| Preset | HF Dataset | Type | Pairs |
+|--------|-----------|------|-------|
+| `sts` | `mteb/stsbenchmark-sts` | Scored (Spearman) | 1,379 |
+| `natural-questions` | `sentence-transformers/natural-questions` | Retrieval (MRR/Recall) | 100,231 |
+| `msmarco` | `sentence-transformers/msmarco-bm25` | Retrieval | 503,000 |
+| `squad` | `sentence-transformers/squad` | Retrieval | 87,599 |
+| `trivia-qa` | `sentence-transformers/trivia-qa` | Retrieval | 73,346 |
+| `gooaq` | `sentence-transformers/gooaq` | Retrieval | 3,012,496 |
+| `hotpotqa` | `sentence-transformers/hotpotqa` | Retrieval | 84,500 |
+
+```bash
+# Evaluate on multiple datasets
+python bench.py --models mpnet bge-small \
+    --datasets sts natural-questions squad \
+    --skip-speed --skip-memory
+
+# Limit pairs for large datasets
+python bench.py --datasets msmarco gooaq --max-pairs 1000
+
+# Use a custom HF dataset (overrides --datasets)
+python bench.py --dataset my-org/my-pairs \
+    --query-col query --passage-col passage --score-col none
+```
+
+Scored datasets (with `--score-col`) report **Spearman correlation**. Pair-only datasets (`--score-col none`) report **MRR**, **Recall@1**, **Recall@5**, and **Recall@10**.
+
+### Export results
+
+```bash
+# Export to CSV
+python bench.py --csv results.csv
+
+# Save charts as PNG
+python bench.py --charts ./results
+
+# Both
+python bench.py --models mpnet bge-small \
+    --datasets sts squad natural-questions \
+    --max-pairs 1000 \
+    --csv results.csv --charts ./results
+```
+
+Charts generated:
+- `quality_<dataset>.png` — Spearman bar chart (scored) or grouped MRR/Recall bars (retrieval)
+- `speed.png` — sentences/second comparison
+- `memory.png` — peak memory usage comparison
+
 ## Metrics
 
 | Dimension | Metric | Method |
 |-----------|--------|--------|
-| Quality | Spearman rho | STS Benchmark test set (1,379 pairs) |
+| Quality (scored) | Spearman rho | Cosine similarity vs gold scores |
+| Quality (pairs) | MRR, Recall@k | Retrieval ranking of positive passages |
 | Speed | Median encode time | Wall-clock over N runs with warmup |
 | Memory | Peak RSS delta | Isolated subprocess via `psutil` |
 
+## CLI reference
+
+```
+--models         Models to benchmark (default: all)
+--corpus-size    Sentences for speed/memory tests (default: 1000)
+--batch-size     Encoding batch size (default: 64)
+--num-runs       Speed benchmark runs (default: 3)
+--skip-quality   Skip quality evaluation
+--skip-speed     Skip speed measurement
+--skip-memory    Skip memory measurement
+--datasets       Dataset presets (default: sts)
+--max-pairs      Limit pairs per dataset
+--dataset        Custom HF dataset (overrides --datasets)
+--config         Dataset config/subset name (e.g. 'triplet')
+--split          Dataset split (default: test)
+--query-col      Query column name (default: sentence1)
+--passage-col    Passage column name (default: sentence2)
+--score-col      Score column (default: score, 'none' for pairs)
+--score-scale    Score normalization divisor (default: 5.0)
+--csv            Export results to CSV
+--charts         Save charts to directory
+```
+
 ## Adding a model
 
 Edit `models.py` and add an entry to `REGISTRY`:
@@ -76,14 +153,15 @@ Edit `models.py` and add an entry to `REGISTRY`:
 
 ```
 embedding-bench/
-├── bench.py           # CLI entry point
-├── models.py          # Model registry
-├── wrapper.py         # Backend wrappers (sbert, fastembed, gguf)
-├── corpus.py          # Sentence corpus builder
-├── report.py          # Table formatting
+├── bench.py           # CLI entry point
+├── models.py          # Model registry
+├── wrapper.py         # Backend wrappers (sbert, fastembed, gguf)
+├── corpus.py          # Sentence corpus builder
+├── dataset_config.py  # Dataset presets and configuration
+├── report.py          # Table formatting, CSV export, charts
 ├── evals/
-│   ├── quality.py     # STS Benchmark evaluation
-│   ├── speed.py       # Latency measurement
-│   └── memory.py      # Memory measurement
+│   ├── quality.py     # STS + retrieval evaluation
+│   ├── speed.py       # Latency measurement
+│   └── memory.py      # Memory measurement
 └── requirements.txt
 ```
bench.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 import argparse
 
 from corpus import build_corpus
+from dataset_config import DATASET_PRESETS, DatasetConfig
 from evals import evaluate_memory, evaluate_quality, evaluate_speed
 from models import REGISTRY
 from report import print_report
@@ -28,15 +29,61 @@ def main(argv: list[str] | None = None) -> None:
     parser.add_argument("--skip-speed", action="store_true")
     parser.add_argument("--skip-memory", action="store_true")
 
+    # Dataset configuration
+    parser.add_argument(
+        "--datasets",
+        nargs="+",
+        default=["sts"],
+        choices=list(DATASET_PRESETS.keys()),
+        help=f"Dataset presets to evaluate (default: sts). "
+             f"Available: {', '.join(DATASET_PRESETS.keys())}",
+    )
+    parser.add_argument("--max-pairs", type=int, default=None,
+                        help="Limit number of pairs per dataset (useful for large datasets)")
+
+    # Custom dataset (overrides --datasets)
+    parser.add_argument("--dataset", default=None,
+                        help="Custom HF dataset name (overrides --datasets)")
+    parser.add_argument("--config", default=None,
+                        help="Dataset config/subset name (e.g. 'triplet')")
+    parser.add_argument("--split", default="test")
+    parser.add_argument("--query-col", default="sentence1")
+    parser.add_argument("--passage-col", default="sentence2")
+    parser.add_argument("--score-col", default="score",
+                        help="Score column name. Pass 'none' for pair-only datasets.")
+    parser.add_argument("--score-scale", type=float, default=5.0)
+
+    # Output options
+    parser.add_argument("--csv", default=None, metavar="PATH",
+                        help="Export results to a CSV file")
+    parser.add_argument("--charts", default=None, metavar="DIR",
+                        help="Save charts to a directory (e.g. ./results)")
+
     args = parser.parse_args(argv)
 
+    # Build list of dataset configs
+    if args.dataset:
+        # Custom dataset overrides presets
+        ds_configs = [DatasetConfig(
+            name=args.dataset,
+            config=args.config,
+            split=args.split,
+            query_col=args.query_col,
+            passage_col=args.passage_col,
+            score_col=None if args.score_col.lower() == "none" else args.score_col,
+            score_scale=args.score_scale,
+        )]
+    else:
+        ds_configs = [DATASET_PRESETS[k] for k in args.datasets]
+
     configs = [REGISTRY[k] for k in args.models]
     baseline_name = next((c.name for c in configs if c.is_baseline), None)
 
+    # Use first dataset for corpus building
     corpus: list[str] | None = None
     if not args.skip_speed or not args.skip_memory:
         print(f"Preparing corpus ({args.corpus_size} sentences)...")
-        corpus = build_corpus(args.corpus_size)
+        corpus = build_corpus(args.corpus_size, ds_configs[0])
 
     results = []
     for cfg in configs:
@@ -47,10 +94,16 @@ def main(argv: list[str] | None = None) -> None:
         result: dict = {"name": cfg.name, "is_baseline": cfg.is_baseline}
 
         if not args.skip_quality:
-            print("  Evaluating quality (STS Benchmark)...")
             model = load_model(cfg)
-            result["quality"] = evaluate_quality(model)
-            print(f"  Quality: {result['quality']:.4f}")
+            quality_results = {}
+            for ds_cfg in ds_configs:
+                ds_key = ds_cfg.name.split("/")[-1]
+                print(f"  Evaluating quality on {ds_cfg.name}...")
+                quality_results[ds_key] = evaluate_quality(
+                    model, ds_cfg, max_pairs=args.max_pairs,
+                )
+                print(f"    {quality_results[ds_key]}")
+            result["quality"] = quality_results
             del model
 
         if not args.skip_speed and corpus is not None:
@@ -67,7 +120,8 @@ def main(argv: list[str] | None = None) -> None:
 
         results.append(result)
 
-    print_report(results, baseline_name=baseline_name)
+    print_report(results, baseline_name=baseline_name,
+                 csv_path=args.csv, chart_dir=args.charts)
 
 
 if __name__ == "__main__":
corpus.py CHANGED
@@ -2,11 +2,15 @@ from __future__ import annotations
 
 from datasets import load_dataset
 
+from dataset_config import DatasetConfig
 
-def build_corpus(size: int) -> list[str]:
-    """Build a corpus of real sentences from the STS Benchmark dataset."""
-    dataset = load_dataset("mteb/stsbenchmark-sts", split="test")
-    sentences = list(dataset["sentence1"]) + list(dataset["sentence2"])
+
+def build_corpus(size: int, ds_cfg: DatasetConfig | None = None) -> list[str]:
+    """Build a corpus of real sentences from the configured dataset."""
+    if ds_cfg is None:
+        ds_cfg = DatasetConfig()
+    dataset = load_dataset(ds_cfg.name, ds_cfg.config, split=ds_cfg.split)
+    sentences = list(dataset[ds_cfg.query_col]) + list(dataset[ds_cfg.passage_col])
     full: list[str] = []
     while len(full) < size:
         full.extend(sentences)
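
Note: `build_corpus` keeps its old one-argument behavior (the `DatasetConfig()` default is the STS Benchmark), so existing callers still work. A minimal usage sketch, not part of the commit:

```python
from corpus import build_corpus
from dataset_config import DATASET_PRESETS

# Default config: STS Benchmark sentences, repeated until 1000 are collected.
sentences = build_corpus(1000)

# Or draw the speed/memory corpus from a retrieval preset instead.
nq_sentences = build_corpus(1000, DATASET_PRESETS["natural-questions"])
```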
dataset_config.py ADDED
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+
+@dataclass
+class DatasetConfig:
+    """Configuration for the evaluation dataset."""
+
+    name: str = "mteb/stsbenchmark-sts"
+    config: str | None = None
+    split: str = "test"
+    query_col: str = "sentence1"
+    passage_col: str = "sentence2"
+    score_col: str | None = "score"
+    score_scale: float = 5.0
+
+
+DATASET_PRESETS: dict[str, DatasetConfig] = {
+    "sts": DatasetConfig(
+        name="mteb/stsbenchmark-sts",
+        split="test",
+        query_col="sentence1",
+        passage_col="sentence2",
+        score_col="score",
+        score_scale=5.0,
+    ),
+    "natural-questions": DatasetConfig(
+        name="sentence-transformers/natural-questions",
+        split="train",
+        query_col="query",
+        passage_col="answer",
+        score_col=None,
+    ),
+    "msmarco": DatasetConfig(
+        name="sentence-transformers/msmarco-bm25",
+        config="triplet",
+        split="train",
+        query_col="query",
+        passage_col="positive",
+        score_col=None,
+    ),
+    "squad": DatasetConfig(
+        name="sentence-transformers/squad",
+        split="train",
+        query_col="question",
+        passage_col="answer",
+        score_col=None,
+    ),
+    "trivia-qa": DatasetConfig(
+        name="sentence-transformers/trivia-qa",
+        split="train",
+        query_col="query",
+        passage_col="answer",
+        score_col=None,
+    ),
+    "gooaq": DatasetConfig(
+        name="sentence-transformers/gooaq",
+        split="train",
+        query_col="question",
+        passage_col="answer",
+        score_col=None,
+    ),
+    "hotpotqa": DatasetConfig(
+        name="sentence-transformers/hotpotqa",
+        config="triplet",
+        split="train",
+        query_col="anchor",
+        passage_col="positive",
+        score_col=None,
+    ),
+}
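
Note: a preset is just a bundle of `load_dataset` arguments plus column names; `score_col=None` is what flips downstream evaluation into retrieval mode. A minimal sketch of resolving a preset by hand, not part of the commit:

```python
from datasets import load_dataset

from dataset_config import DATASET_PRESETS

cfg = DATASET_PRESETS["msmarco"]  # config="triplet" selects the HF subset

# Same call shape that corpus.py and evals/quality.py use.
ds = load_dataset(cfg.name, cfg.config, split=cfg.split)

queries = ds[cfg.query_col]     # "query" column
passages = ds[cfg.passage_col]  # "positive" column
print(len(queries), "pairs;", "scored" if cfg.score_col else "pair-only")
```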
evals/quality.py CHANGED
@@ -4,21 +4,79 @@ import numpy as np
 from datasets import load_dataset
 from scipy.stats import spearmanr
 
+from dataset_config import DatasetConfig
 
-def evaluate_quality(model) -> float:
-    """Return Spearman correlation on the STS Benchmark test set."""
-    dataset = load_dataset("mteb/stsbenchmark-sts", split="test")
-    sentences1 = list(dataset["sentence1"])
-    sentences2 = list(dataset["sentence2"])
-    gold_scores = [s / 5.0 for s in dataset["score"]]
-
-    emb1 = model.encode(sentences1)
-    emb2 = model.encode(sentences2)
-
-    # Row-wise cosine similarity
-    cos_sims = np.sum(emb1 * emb2, axis=1) / (
-        np.linalg.norm(emb1, axis=1) * np.linalg.norm(emb2, axis=1)
-    )
-
-    correlation, _ = spearmanr(cos_sims, gold_scores)
-    return correlation
+
+def _normalize(emb: np.ndarray) -> np.ndarray:
+    """L2-normalize each row."""
+    norms = np.linalg.norm(emb, axis=1, keepdims=True)
+    return emb / norms
+
+
+def _retrieval_metrics(emb_q: np.ndarray, emb_p: np.ndarray) -> dict[str, float]:
+    """Compute MRR and Recall@k assuming query i matches passage i."""
+    emb_q = _normalize(emb_q)
+    emb_p = _normalize(emb_p)
+
+    # Similarity matrix: (num_queries, num_passages)
+    sims = emb_q @ emb_p.T
+
+    n = sims.shape[0]
+    # For each query, rank passages by descending similarity
+    # ranks[i] = rank of the correct passage (0-indexed)
+    sorted_indices = np.argsort(-sims, axis=1)
+    ranks = np.array([int(np.where(sorted_indices[i] == i)[0][0]) for i in range(n)])
+
+    mrr = float(np.mean(1.0 / (ranks + 1)))
+    recall_1 = float(np.mean(ranks < 1))
+    recall_5 = float(np.mean(ranks < 5))
+    recall_10 = float(np.mean(ranks < 10))
+
+    return {
+        "mrr": round(mrr, 4),
+        "recall@1": round(recall_1, 4),
+        "recall@5": round(recall_5, 4),
+        "recall@10": round(recall_10, 4),
+    }
+
+
+def evaluate_quality(
+    model,
+    ds_cfg: DatasetConfig | None = None,
+    max_pairs: int | None = None,
+) -> dict[str, float]:
+    """Evaluate embedding quality on a dataset.
+
+    Returns a dict with either {"spearman": float} for scored datasets
+    or {"mrr", "recall@1", "recall@5", "recall@10"} for pair datasets.
+    """
+    if ds_cfg is None:
+        ds_cfg = DatasetConfig()
+
+    dataset = load_dataset(ds_cfg.name, ds_cfg.config, split=ds_cfg.split)
+    queries = list(dataset[ds_cfg.query_col])
+    passages = list(dataset[ds_cfg.passage_col])
+
+    if max_pairs is not None and len(queries) > max_pairs:
+        queries = queries[:max_pairs]
+        passages = passages[:max_pairs]
+
+    emb_q = model.encode(queries)
+    emb_p = model.encode(passages)
+
+    if ds_cfg.score_col is not None:
+        # Scored mode: Spearman correlation
+        scores = list(dataset[ds_cfg.score_col])
+        if max_pairs is not None and len(scores) > max_pairs:
+            scores = scores[:max_pairs]
+        gold_scores = [s / ds_cfg.score_scale for s in scores]
+
+        cos_sims = np.sum(emb_q * emb_p, axis=1) / (
+            np.linalg.norm(emb_q, axis=1) * np.linalg.norm(emb_p, axis=1)
+        )
+
+        correlation, _ = spearmanr(cos_sims, gold_scores)
+        return {"spearman": round(float(correlation), 4)}
+
+    # Pair mode: retrieval metrics
+    return _retrieval_metrics(emb_q, emb_p)
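
Note: `_retrieval_metrics` treats pair i as the only relevant match for query i, so this is in-batch retrieval over the (possibly `--max-pairs`-truncated) passage set. A hand-checkable toy case, not part of the commit (it assumes `evals` is importable as a package, which the `from evals import ...` line in bench.py suggests):

```python
import numpy as np

from evals.quality import _retrieval_metrics

# Query 0 ranks its own passage first (rank 0); query 1 is closer to
# passage 0 than to its own passage 1 (rank 1).
emb_q = np.array([[1.0, 0.0],
                  [0.9, 0.1]])
emb_p = np.array([[1.0, 0.0],
                  [0.0, 1.0]])

print(_retrieval_metrics(emb_q, emb_p))
# MRR = (1/1 + 1/2) / 2 = 0.75, Recall@1 = 0.5, Recall@5 = Recall@10 = 1.0
```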
models.py CHANGED
@@ -22,24 +22,24 @@ REGISTRY: dict[str, ModelConfig] = {
         name="bge-small-en-v1.5",
         model_id="BAAI/bge-small-en-v1.5",
     ),
-    "bge-small-fe": ModelConfig(
-        name="bge-small-en-v1.5 (fastembed)",
-        model_id="BAAI/bge-small-en-v1.5",
-        backend="fastembed",
-    ),
-    "all-minilm-fe": ModelConfig(
-        name="all-MiniLM-L6-v2 (fastembed)",
-        model_id="sentence-transformers/all-MiniLM-L6-v2",
-        backend="fastembed",
-    ),
-    "bge-small-le": ModelConfig(
-        name="bge-small-en-v1.5 (libembedding)",
-        model_id="BAAI/bge-small-en-v1.5",
-        backend="libembedding",
-    ),
-    "all-minilm-le": ModelConfig(
-        name="all-MiniLM-L6-v2 (libembedding)",
-        model_id="sentence-transformers/all-MiniLM-L6-v2",
-        backend="libembedding",
-    ),
+    # "bge-small-fe": ModelConfig(
+    #     name="bge-small-en-v1.5 (fastembed)",
+    #     model_id="BAAI/bge-small-en-v1.5",
+    #     backend="fastembed",
+    # ),
+    # "all-minilm-fe": ModelConfig(
+    #     name="all-MiniLM-L6-v2 (fastembed)",
+    #     model_id="sentence-transformers/all-MiniLM-L6-v2",
+    #     backend="fastembed",
+    # ),
+    # "bge-small-le": ModelConfig(
+    #     name="bge-small-en-v1.5 (libembedding)",
+    #     model_id="BAAI/bge-small-en-v1.5",
+    #     backend="libembedding",
+    # ),
+    # "all-minilm-le": ModelConfig(
+    #     name="all-MiniLM-L6-v2 (libembedding)",
+    #     model_id="sentence-transformers/all-MiniLM-L6-v2",
+    #     backend="libembedding",
+    # ),
 }
report.py CHANGED
@@ -1,13 +1,179 @@
 from __future__ import annotations
 
+import csv
+import os
+from pathlib import Path
 from typing import Any, Optional
 
+import matplotlib.pyplot as plt
+import numpy as np
 from tabulate import tabulate
 
 
-def print_report(results: list[dict[str, Any]], baseline_name: Optional[str] = None) -> None:
-    """Print a formatted comparison table to stdout."""
-    headers = ["Model", "Quality (STS)", "Speed (sent/s)", "Median Time (s)", "Memory (MB)"]
+def _format_metrics(metrics: dict[str, float]) -> str:
+    """Format a single dataset's metrics into a compact string."""
+    if "spearman" in metrics:
+        return f"{metrics['spearman']:.4f}"
+    if "mrr" in metrics:
+        return f"MRR={metrics['mrr']:.4f} R@1={metrics['recall@1']:.4f}"
+    return "—"
+
+
+def _flatten_result(r: dict[str, Any]) -> dict[str, Any]:
+    """Flatten a single result dict into a flat key-value dict for CSV."""
+    flat: dict[str, Any] = {"model": r["name"]}
+
+    for ds_key, metrics in r.get("quality", {}).items():
+        for metric_name, value in metrics.items():
+            flat[f"{ds_key}/{metric_name}"] = value
+
+    speed = r.get("speed")
+    if speed:
+        flat["speed_sent_per_s"] = speed["sentences_per_second"]
+        flat["median_time_s"] = speed["median_seconds"]
+
+    memory = r.get("memory_mb")
+    if memory is not None:
+        flat["memory_mb"] = memory
+
+    return flat
+
+
+def export_csv(results: list[dict[str, Any]], path: str) -> None:
+    """Export results to a CSV file."""
+    rows = [_flatten_result(r) for r in results]
+    fieldnames = list(rows[0].keys())
+    # Ensure all fields are captured
+    for row in rows[1:]:
+        for k in row:
+            if k not in fieldnames:
+                fieldnames.append(k)
+
+    with open(path, "w", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=fieldnames)
+        writer.writeheader()
+        writer.writerows(rows)
+    print(f"CSV saved to {path}")
+
+
+def plot_charts(results: list[dict[str, Any]], output_dir: str) -> None:
+    """Generate and save benchmark charts."""
+    os.makedirs(output_dir, exist_ok=True)
+    models = [r["name"] for r in results]
+
+    # --- Quality charts (one per dataset) ---
+    ds_keys: list[str] = []
+    for r in results:
+        quality = r.get("quality")
+        if quality:
+            ds_keys = list(quality.keys())
+            break
+
+    for ds_key in ds_keys:
+        first_metrics = None
+        for r in results:
+            m = r.get("quality", {}).get(ds_key)
+            if m:
+                first_metrics = m
+                break
+        if not first_metrics:
+            continue
+
+        if "spearman" in first_metrics:
+            # Single bar chart for spearman
+            values = [r.get("quality", {}).get(ds_key, {}).get("spearman", 0) for r in results]
+            fig, ax = plt.subplots(figsize=(max(6, len(models) * 1.2), 5))
+            bars = ax.bar(models, values, color="#4C72B0")
+            ax.set_ylabel("Spearman Correlation")
+            ax.set_title(f"Quality — {ds_key}")
+            ax.set_ylim(0, 1)
+            for bar, v in zip(bars, values):
+                ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
+                        f"{v:.4f}", ha="center", va="bottom", fontsize=9)
+            plt.xticks(rotation=30, ha="right")
+            plt.tight_layout()
+            fig.savefig(os.path.join(output_dir, f"quality_{ds_key}.png"), dpi=150)
+            plt.close(fig)
+        else:
+            # Grouped bar chart for retrieval metrics
+            metric_names = ["mrr", "recall@1", "recall@5", "recall@10"]
+            x = np.arange(len(models))
+            width = 0.18
+            colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"]
+
+            fig, ax = plt.subplots(figsize=(max(8, len(models) * 2), 5))
+            for i, (metric, color) in enumerate(zip(metric_names, colors)):
+                values = [r.get("quality", {}).get(ds_key, {}).get(metric, 0) for r in results]
+                offset = (i - 1.5) * width
+                bars = ax.bar(x + offset, values, width, label=metric, color=color)
+                for bar, v in zip(bars, values):
+                    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
+                            f"{v:.2f}", ha="center", va="bottom", fontsize=7)
+            ax.set_ylabel("Score")
+            ax.set_title(f"Retrieval Quality — {ds_key}")
+            ax.set_ylim(0, 1.15)
+            ax.set_xticks(x)
+            ax.set_xticklabels(models, rotation=30, ha="right")
+            ax.legend()
+            plt.tight_layout()
+            fig.savefig(os.path.join(output_dir, f"quality_{ds_key}.png"), dpi=150)
+            plt.close(fig)
+
+    # --- Speed chart ---
+    speed_values = [r.get("speed", {}).get("sentences_per_second", 0) for r in results]
+    if any(v > 0 for v in speed_values):
+        fig, ax = plt.subplots(figsize=(max(6, len(models) * 1.2), 5))
+        bars = ax.bar(models, speed_values, color="#55A868")
+        ax.set_ylabel("Sentences / second")
+        ax.set_title("Encoding Speed")
+        for bar, v in zip(bars, speed_values):
+            if v > 0:
+                ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
+                        str(v), ha="center", va="bottom", fontsize=9)
+        plt.xticks(rotation=30, ha="right")
+        plt.tight_layout()
+        fig.savefig(os.path.join(output_dir, "speed.png"), dpi=150)
+        plt.close(fig)
+
+    # --- Memory chart ---
+    mem_values = [r.get("memory_mb", 0) for r in results]
+    if any(v > 0 for v in mem_values):
+        fig, ax = plt.subplots(figsize=(max(6, len(models) * 1.2), 5))
+        bars = ax.bar(models, mem_values, color="#C44E52")
+        ax.set_ylabel("Peak Memory (MB)")
+        ax.set_title("Memory Usage")
+        for bar, v in zip(bars, mem_values):
+            if v > 0:
+                ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
+                        str(v), ha="center", va="bottom", fontsize=9)
+        plt.xticks(rotation=30, ha="right")
+        plt.tight_layout()
+        fig.savefig(os.path.join(output_dir, "memory.png"), dpi=150)
+        plt.close(fig)
+
+    print(f"Charts saved to {output_dir}/")
+
+
+def print_report(
+    results: list[dict[str, Any]],
+    baseline_name: Optional[str] = None,
+    csv_path: Optional[str] = None,
+    chart_dir: Optional[str] = None,
+) -> None:
+    """Print a formatted comparison table and optionally export CSV/charts."""
+    # Discover dataset columns from the first result that has quality data
+    ds_keys: list[str] = []
+    for r in results:
+        quality = r.get("quality")
+        if quality:
+            ds_keys = list(quality.keys())
+            break
+
+    headers = ["Model"]
+    for ds_key in ds_keys:
+        headers.append(f"Quality ({ds_key})")
+    headers.extend(["Speed (sent/s)", "Median Time (s)", "Memory (MB)"])
+
     rows: list[list[Any]] = []
 
     for r in results:
@@ -15,20 +181,29 @@ def print_report(results: list[dict[str, Any]], baseline_name: Optional[str] = N
         if r.get("is_baseline"):
            name += " [B]"
 
-        quality = r.get("quality")
+        quality = r.get("quality", {})
         speed = r.get("speed")
         memory = r.get("memory_mb")
 
-        rows.append([
-            name,
-            f"{quality:.4f}" if quality is not None else "—",
+        row: list[Any] = [name]
+        for ds_key in ds_keys:
+            metrics = quality.get(ds_key)
+            row.append(_format_metrics(metrics) if metrics else "—")
+        row.extend([
            f"{speed['sentences_per_second']}" if speed else "—",
            f"{speed['median_seconds']}" if speed else "—",
            f"{memory}" if memory is not None else "—",
        ])
+        rows.append(row)
 
     print()
     print(tabulate(rows, headers=headers, tablefmt="simple"))
     if baseline_name:
         print(f"\n[B] = baseline ({baseline_name})")
     print()
+
+    if csv_path:
+        export_csv(results, csv_path)
+
+    if chart_dir:
+        plot_charts(results, chart_dir)
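
Note: the CSV columns come straight from `_flatten_result`, which namespaces each quality metric as `<dataset>/<metric>`; that is exactly the header visible in results.csv below. An illustrative sketch, not part of the commit, with values taken from that file:

```python
from report import _flatten_result

result = {
    "name": "bge-small-en-v1.5",
    "is_baseline": False,
    "quality": {
        "stsbenchmark-sts": {"spearman": 0.8615},
        "squad": {"mrr": 0.2257, "recall@1": 0.081,
                  "recall@5": 0.382, "recall@10": 0.614},
    },
    "speed": {"sentences_per_second": 2144.4, "median_seconds": 0.4663},
    "memory_mb": 353.2,
}

print(_flatten_result(result))
# {'model': 'bge-small-en-v1.5', 'stsbenchmark-sts/spearman': 0.8615,
#  'squad/mrr': 0.2257, 'squad/recall@1': 0.081, 'squad/recall@5': 0.382,
#  'squad/recall@10': 0.614, 'speed_sent_per_s': 2144.4,
#  'median_time_s': 0.4663, 'memory_mb': 353.2}
```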
requirements.txt CHANGED
@@ -7,3 +7,4 @@ fastembed
 libembedding
 numpy
 scipy
+matplotlib
results.csv ADDED
@@ -0,0 +1,3 @@
+model,stsbenchmark-sts/spearman,natural-questions/mrr,natural-questions/recall@1,natural-questions/recall@5,natural-questions/recall@10,squad/mrr,squad/recall@1,squad/recall@5,squad/recall@10,speed_sent_per_s,median_time_s,memory_mb
+all-mpnet-base-v2,0.8519,0.9762,0.955,0.999,1.0,0.2282,0.075,0.405,0.627,775.4,1.2897,369.4
+bge-small-en-v1.5,0.8615,0.9557,0.927,0.988,0.995,0.2257,0.081,0.382,0.614,2144.4,0.4663,353.2
results/memory.png ADDED
results/quality_natural-questions.png ADDED
results/quality_squad.png ADDED
results/quality_stsbenchmark-sts.png ADDED
results/speed.png ADDED