Spaces:
Running
Running
| from __future__ import annotations | |
| import numpy as np | |
| from models import ModelConfig | |
| def _apply_prefix(cfg: ModelConfig, sentences: list[str], is_query: bool) -> list[str]: | |
| prefix = cfg.query_prefix if is_query else cfg.passage_prefix | |
| if not prefix: | |
| return sentences | |
| return [prefix + s for s in sentences] | |
class SBertWrapper:
    """Thin adapter around ``sentence_transformers.SentenceTransformer``."""

    def __init__(self, cfg: ModelConfig):
        # Imported lazily so this backend is only required when actually used.
        from sentence_transformers import SentenceTransformer

        self._cfg = cfg
        extra: dict = {}
        if cfg.trust_remote_code:
            extra["trust_remote_code"] = True
        self._model = SentenceTransformer(cfg.model_id, **extra)

    def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
        """Embed *sentences* and return the model's embedding array.

        Prefers the backend's named-prompt mechanism when the config declares
        a prompt name for this side (query vs. passage); otherwise falls back
        to plain string prefixes via ``_apply_prefix``.

        NOTE(review): when a caller passes their own ``prompt_name`` kwarg AND
        the config also declares one, the prefix branch still runs — confirm
        this double application is intended.
        """
        kwargs.setdefault("show_progress_bar", False)
        if is_query:
            configured_prompt = self._cfg.query_prompt_name
        else:
            configured_prompt = self._cfg.passage_prompt_name
        if configured_prompt and "prompt_name" not in kwargs:
            kwargs["prompt_name"] = configured_prompt
        else:
            sentences = _apply_prefix(self._cfg, sentences, is_query)
        return self._model.encode(sentences, batch_size=batch_size, **kwargs)
class GGUFWrapper:
    """Thin adapter around ``llama_cpp.Llama`` running in embedding mode."""

    def __init__(self, cfg: ModelConfig):
        # Imported lazily so this backend is only required when actually used.
        from huggingface_hub import hf_hub_download
        from llama_cpp import Llama

        self._cfg = cfg
        gguf_path = hf_hub_download(repo_id=cfg.model_id, filename=cfg.gguf_file)
        self._model = Llama(
            model_path=gguf_path, embedding=True, n_ctx=512, verbose=False
        )

    def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
        """Embed *sentences* in batches and return a float32 array.

        Applies the configured query/passage prefix, then feeds slices of at
        most *batch_size* sentences to ``create_embedding``.
        """
        sentences = _apply_prefix(self._cfg, sentences, is_query)
        vectors: list = []
        start = 0
        while start < len(sentences):
            chunk = sentences[start : start + batch_size]
            result = self._model.create_embedding(chunk)
            vectors.extend(entry["embedding"] for entry in result["data"])
            start += batch_size
        return np.array(vectors, dtype=np.float32)
class FastEmbedWrapper:
    """Thin adapter around ``fastembed.TextEmbedding``."""

    def __init__(self, cfg: ModelConfig):
        # Imported lazily so this backend is only required when actually used.
        from fastembed import TextEmbedding

        self._cfg = cfg
        self._model = TextEmbedding(model_name=cfg.model_id)

    def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
        """Embed *sentences* and return a float32 array.

        Applies the configured query/passage prefix before delegating to the
        backend's lazy ``embed`` iterator.
        """
        sentences = _apply_prefix(self._cfg, sentences, is_query)
        vectors = [vec for vec in self._model.embed(sentences, batch_size=batch_size)]
        return np.array(vectors, dtype=np.float32)
class LibEmbedWrapper:
    """Thin adapter around ``libembedding.TextEmbedding``."""

    def __init__(self, cfg: ModelConfig):
        # Imported lazily so this backend is only required when actually used.
        from libembedding import TextEmbedding

        self._cfg = cfg
        self._model = TextEmbedding(cfg.model_id)

    def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
        """Embed *sentences* and return a float32 array.

        Applies the configured query/passage prefix before delegating to the
        backend's lazy ``embed`` iterator.
        """
        sentences = _apply_prefix(self._cfg, sentences, is_query)
        vectors = [vec for vec in self._model.embed(sentences, batch_size=batch_size)]
        return np.array(vectors, dtype=np.float32)
def load_model(cfg: ModelConfig) -> SBertWrapper | GGUFWrapper | FastEmbedWrapper | LibEmbedWrapper:
    """Factory: return the wrapper matching ``cfg.backend``.

    Unrecognized backends fall back to the sentence-transformers wrapper.
    """
    backends = {
        "gguf": GGUFWrapper,
        "fastembed": FastEmbedWrapper,
        "libembedding": LibEmbedWrapper,
    }
    wrapper_cls = backends.get(cfg.backend, SBertWrapper)
    return wrapper_cls(cfg)