embedding-bench / wrapper.py
AmrYassinIsFree
adding more models, fix prefixs
7c4c461
from __future__ import annotations
import numpy as np
from models import ModelConfig
def _apply_prefix(cfg: ModelConfig, sentences: list[str], is_query: bool) -> list[str]:
prefix = cfg.query_prefix if is_query else cfg.passage_prefix
if not prefix:
return sentences
return [prefix + s for s in sentences]
class SBertWrapper:
"""Wraps sentence_transformers.SentenceTransformer."""
def __init__(self, cfg: ModelConfig):
from sentence_transformers import SentenceTransformer
self._cfg = cfg
load_kwargs: dict = {}
if cfg.trust_remote_code:
load_kwargs["trust_remote_code"] = True
self._model = SentenceTransformer(cfg.model_id, **load_kwargs)
def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
kwargs.setdefault("show_progress_bar", False)
prompt_name = self._cfg.query_prompt_name if is_query else self._cfg.passage_prompt_name
if prompt_name and "prompt_name" not in kwargs:
kwargs["prompt_name"] = prompt_name
else:
sentences = _apply_prefix(self._cfg, sentences, is_query)
return self._model.encode(sentences, batch_size=batch_size, **kwargs)
class GGUFWrapper:
"""Wraps llama_cpp.Llama in embedding mode."""
def __init__(self, cfg: ModelConfig):
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
self._cfg = cfg
path = hf_hub_download(repo_id=cfg.model_id, filename=cfg.gguf_file)
self._model = Llama(
model_path=path, embedding=True, n_ctx=512, verbose=False
)
def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
sentences = _apply_prefix(self._cfg, sentences, is_query)
all_embeddings = []
for i in range(0, len(sentences), batch_size):
batch = sentences[i : i + batch_size]
response = self._model.create_embedding(batch)
embeddings = [item["embedding"] for item in response["data"]]
all_embeddings.extend(embeddings)
return np.array(all_embeddings, dtype=np.float32)
class FastEmbedWrapper:
"""Wraps fastembed.TextEmbedding."""
def __init__(self, cfg: ModelConfig):
from fastembed import TextEmbedding
self._cfg = cfg
self._model = TextEmbedding(model_name=cfg.model_id)
def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
sentences = _apply_prefix(self._cfg, sentences, is_query)
embeddings = list(self._model.embed(sentences, batch_size=batch_size))
return np.array(embeddings, dtype=np.float32)
class LibEmbedWrapper:
"""Wraps libembedding.TextEmbedding."""
def __init__(self, cfg: ModelConfig):
from libembedding import TextEmbedding
self._cfg = cfg
self._model = TextEmbedding(cfg.model_id)
def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
sentences = _apply_prefix(self._cfg, sentences, is_query)
embeddings = list(self._model.embed(sentences, batch_size=batch_size))
return np.array(embeddings, dtype=np.float32)
def load_model(cfg: ModelConfig) -> SBertWrapper | GGUFWrapper | FastEmbedWrapper | LibEmbedWrapper:
"""Factory: returns the right wrapper for the model's backend."""
if cfg.backend == "gguf":
return GGUFWrapper(cfg)
if cfg.backend == "fastembed":
return FastEmbedWrapper(cfg)
if cfg.backend == "libembedding":
return LibEmbedWrapper(cfg)
return SBertWrapper(cfg)