File size: 3,735 Bytes
f56dbf3
 
 
 
 
 
 
7c4c461
 
 
 
 
 
 
f56dbf3
 
 
 
 
7c4c461
 
 
 
 
f56dbf3
7c4c461
f56dbf3
7c4c461
 
 
 
 
f56dbf3
 
 
 
 
 
 
 
 
 
7c4c461
f56dbf3
 
 
 
 
7c4c461
 
f56dbf3
 
 
 
 
 
 
 
 
 
 
 
 
 
7c4c461
f56dbf3
 
7c4c461
 
f56dbf3
 
 
 
7052097
 
 
 
 
7c4c461
7052097
 
7c4c461
 
7052097
 
 
 
 
f56dbf3
 
 
 
 
7052097
 
f56dbf3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import annotations

import numpy as np

from models import ModelConfig


def _apply_prefix(cfg: ModelConfig, sentences: list[str], is_query: bool) -> list[str]:
    prefix = cfg.query_prefix if is_query else cfg.passage_prefix
    if not prefix:
        return sentences
    return [prefix + s for s in sentences]


class SBertWrapper:
    """Wraps sentence_transformers.SentenceTransformer.

    Query/passage asymmetry is expressed either through the model's named
    prompts (``prompt_name``) or through a plain-text prefix — never both
    at once.
    """

    def __init__(self, cfg: ModelConfig):
        # Imported lazily so sentence-transformers is only required when
        # this backend is actually selected.
        from sentence_transformers import SentenceTransformer
        self._cfg = cfg
        load_kwargs: dict = {}
        if cfg.trust_remote_code:
            load_kwargs["trust_remote_code"] = True
        self._model = SentenceTransformer(cfg.model_id, **load_kwargs)

    def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
        """Embed *sentences*; pass is_query=True for search queries.

        Extra **kwargs are forwarded to SentenceTransformer.encode.
        """
        kwargs.setdefault("show_progress_bar", False)
        prompt_name = self._cfg.query_prompt_name if is_query else self._cfg.passage_prompt_name
        if "prompt_name" in kwargs:
            # Caller supplied an explicit prompt — honor it and do NOT also
            # prepend the plain-text prefix. (Previously the prefix was
            # applied in this case, stacking two prompt mechanisms.)
            pass
        elif prompt_name:
            kwargs["prompt_name"] = prompt_name
        else:
            # No named prompt configured: fall back to plain-text prefixing.
            sentences = _apply_prefix(self._cfg, sentences, is_query)
        return self._model.encode(sentences, batch_size=batch_size, **kwargs)


class GGUFWrapper:
    """Wraps llama_cpp.Llama in embedding mode."""

    def __init__(self, cfg: ModelConfig):
        # Deferred imports: these packages are only needed for this backend.
        from huggingface_hub import hf_hub_download
        from llama_cpp import Llama

        self._cfg = cfg
        # Download (or reuse from the local cache) the GGUF weights file.
        gguf_path = hf_hub_download(repo_id=cfg.model_id, filename=cfg.gguf_file)
        self._model = Llama(
            model_path=gguf_path, embedding=True, n_ctx=512, verbose=False
        )

    def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
        """Embed *sentences* batch-by-batch and stack into one float32 array."""
        prefixed = _apply_prefix(self._cfg, sentences, is_query)
        vectors: list[list[float]] = []
        start = 0
        while start < len(prefixed):
            chunk = prefixed[start : start + batch_size]
            result = self._model.create_embedding(chunk)
            # Response follows the OpenAI-style shape: {"data": [{"embedding": [...]}, ...]}
            vectors.extend(row["embedding"] for row in result["data"])
            start += batch_size
        return np.array(vectors, dtype=np.float32)


class FastEmbedWrapper:
    """Wraps fastembed.TextEmbedding."""

    def __init__(self, cfg: ModelConfig):
        # Lazy import keeps fastembed optional for the other backends.
        from fastembed import TextEmbedding
        self._cfg = cfg
        self._model = TextEmbedding(model_name=cfg.model_id)

    def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
        """Embed *sentences* as a float32 array (queries get the query prefix)."""
        texts = _apply_prefix(self._cfg, sentences, is_query)
        # embed() yields vectors lazily; materialize before building the array.
        embedded = self._model.embed(texts, batch_size=batch_size)
        return np.array(list(embedded), dtype=np.float32)


class LibEmbedWrapper:
    """Wraps libembedding.TextEmbedding."""

    def __init__(self, cfg: ModelConfig):
        # Deferred import: libembedding is only required for this backend.
        from libembedding import TextEmbedding
        self._cfg = cfg
        self._model = TextEmbedding(cfg.model_id)

    def encode(self, sentences: list[str], batch_size: int = 64, is_query: bool = False, **kwargs) -> np.ndarray:
        """Return float32 embeddings for *sentences* (queries get the query prefix)."""
        texts = _apply_prefix(self._cfg, sentences, is_query)
        # Unpack the lazy embed() generator into a concrete list of vectors.
        rows = [*self._model.embed(texts, batch_size=batch_size)]
        return np.array(rows, dtype=np.float32)


def load_model(cfg: ModelConfig) -> SBertWrapper | GGUFWrapper | FastEmbedWrapper | LibEmbedWrapper:
    """Factory: returns the right wrapper for the model's backend."""
    # Dispatch table; any unrecognized backend falls back to sentence-transformers.
    wrappers = {
        "gguf": GGUFWrapper,
        "fastembed": FastEmbedWrapper,
        "libembedding": LibEmbedWrapper,
    }
    return wrappers.get(cfg.backend, SBertWrapper)(cfg)