File size: 7,769 Bytes
762d748
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import dataclasses
from typing import TYPE_CHECKING, List, Optional, Union

from outlines.generate.api import GenerationParameters, SamplingParameters

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerBase
    from vllm import LLM
    from vllm.sampling_params import SamplingParams


class VLLM:
    """Represents a vLLM model.

    We wrap models from model providing libraries in order to give all of
    them the same interface in Outlines and allow users to easily switch
    between providers. This class wraps the `vllm.LLM` class from the
    `vllm` library.

    Attributes
    ----------
    model
        The wrapped `vllm.LLM` instance.
    lora_request
        The LoRA request forwarded to every call to `generate`, or `None`
        when no adapter has been loaded with `load_lora`.
    tokenizer
        The model's tokenizer after it went through `adapt_tokenizer`.

    """

    def __init__(self, model: "LLM"):
        """Wrap `model` and adapt its tokenizer.

        Arguments
        ---------
        model
            The `vllm.LLM` instance to wrap.

        """
        self.model = model
        self.lora_request = None

        self.tokenizer = self._get_tokenizer()

    def _get_tokenizer(self):
        """Fetch the tokenizer of the wrapped model and adapt it.

        The tokenizer is looked up, in order, via `model.get_tokenizer()`,
        `model.tokenizer.tokenizer` and `model.tokenizer`.

        Returns
        -------
        The tokenizer returned by `adapt_tokenizer`.

        Raises
        ------
        ValueError
            If the model exposes neither a `get_tokenizer` method nor a
            `tokenizer` attribute.

        """
        if hasattr(self.model, "get_tokenizer"):
            tokenizer = self.model.get_tokenizer()
        elif hasattr(self.model, "tokenizer"):
            # Some wrappers nest the actual tokenizer one level deeper.
            if hasattr(self.model.tokenizer, "tokenizer"):
                tokenizer = self.model.tokenizer.tokenizer
            else:
                tokenizer = self.model.tokenizer
        else:
            raise ValueError(
                "The provided LLM instance neither has a "
                "`tokenizer` attribute or a `get_tokenizer` method."
            )
        return adapt_tokenizer(tokenizer=tokenizer)

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
        *,
        sampling_params: Optional["SamplingParams"] = None,
        use_tqdm: bool = True,
    ):
        """Generate text using vLLM.

        Arguments
        ---------
        prompts
            A prompt or list of prompts.
        generation_parameters
            An instance of `GenerationParameters` that contains the maximum
            number of tokens, stop sequences and seed. All the arguments to
            `SequenceGeneratorAdapter`'s `__call__` method. Values that are
            not `None` overwrite the corresponding `sampling_params` fields.
        logits_processor
            The logits processor to use when generating text, or `None`.
            It always replaces the logits processors of `sampling_params`.
        sampling_parameters
            An instance of `SamplingParameters`, a dataclass that contains
            the name of the sampler to use and related parameters as available
            in Outlines. These only fill in `sampling_params` fields that
            still hold their default value (`n == 1`, `top_p == 1.0`,
            `top_k == -1`, `temperature == 1.0`).
        sampling_params
            An instance of `vllm.sampling_params.SamplingParams`. Non-default
            values passed via this dataclass take precedence over the values
            in `sampling_parameters`. The object is shallow-copied before
            being updated, so the caller's instance is never modified. See
            the vLLM documentation for more details:
            https://docs.vllm.ai/en/latest/dev/sampling_params.html.
        use_tqdm
            A boolean in order to display progress bar while inferencing

        Returns
        -------
        The generated text, of shape `(n_batch, n_samples)`. If there are only
        one batch and several samples, the list is of shape `(n_samples)`. If
        this is a batch with several sequences but only one sample the list is
        of shape `(n_batch)`. If there is only one sequence and one sample, a
        string is returned. If the model returns no result at all (e.g. an
        empty list of prompts), an empty list is returned.

        """
        import copy

        if sampling_params is None:
            # Only needed to build the default; deferred so that callers who
            # supply their own parameters do not pay for the import here.
            from vllm.sampling_params import SamplingParams

            sampling_params = SamplingParams()
        else:
            # Work on a copy: updating the caller's object in place would leak
            # state between calls and defeat the "still at default" checks
            # below when the same instance is reused. A shallow copy is enough
            # because attributes are only rebound, never mutated.
            sampling_params = copy.copy(sampling_params)

        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)

        # We only update the values in `sampling_params` if they
        # are specified by the user when calling the generator.
        if max_tokens is not None:
            sampling_params.max_tokens = max_tokens
        if stop_at is not None:
            if isinstance(stop_at, str):
                stop_at = [stop_at]
            sampling_params.stop = stop_at
        if seed is not None:
            sampling_params.seed = seed

        sampling_params.logits_processors = (
            [logits_processor] if logits_processor is not None else []
        )

        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
            sampling_parameters
        )

        # We only update the values in `sampling_params` that
        # were not specified by the user.
        if sampling_params.n == 1:
            sampling_params.n = num_samples
            sampling_params.best_of = num_samples
        if top_p is not None and sampling_params.top_p == 1.0:
            sampling_params.top_p = top_p
        if top_k is not None and sampling_params.top_k == -1:
            sampling_params.top_k = top_k
            # TODO: remove this if statement once fixed
            # https://github.com/vllm-project/vllm/issues/5404#issuecomment-2175972897
            if top_k == 1:
                sampling_params.repetition_penalty = 0
        if temperature is not None and sampling_params.temperature == 1.0:
            sampling_params.temperature = temperature
        if sampler == "beam_search":
            sampling_params.use_beam_search = True

        results = self.model.generate(
            prompts,
            sampling_params=sampling_params,
            lora_request=self.lora_request,
            use_tqdm=use_tqdm,
        )
        results = [[sample.text for sample in batch.outputs] for batch in results]

        # Nothing was generated: there is no first batch to inspect.
        if not results:
            return results

        batch_size = len(results)
        sample_size = len(results[0])

        # Squeeze the dimensions of size one.
        if batch_size == 1 and sample_size == 1:
            return results[0][0]
        elif batch_size == 1:
            return results[0]
        elif sample_size == 1:
            return [batch[0] for batch in results]

        return results

    def stream(self, *args, **kwargs):
        """Return a text generator.

        Streaming is not yet available for `vllm.LLM`.

        TODO: Implement the streaming functionality ourselves.

        Raises
        ------
        NotImplementedError
            Always.

        """
        raise NotImplementedError(
            "Streaming is not available for the vLLM integration."
        )

    def load_lora(self, adapter_path: Optional[str]):
        """Set or clear the LoRA adapter used by subsequent generations.

        Arguments
        ---------
        adapter_path
            Path of the adapter. It is passed to `LoRARequest` as both the
            first and third positional argument, with `1` as the second.
            `None` removes the current LoRA request.

        """
        from vllm.lora.request import LoRARequest

        if adapter_path is None:
            self.lora_request = None
        else:
            self.lora_request = LoRARequest(adapter_path, 1, adapter_path)


def vllm(model_name: str, **vllm_model_params):
    """Instantiate a `vllm.LLM` and wrap it in a `VLLM` object.

    Arguments
    ---------
    model_name
        The name of the model to load from the HuggingFace hub. It is
        passed as the first positional argument to `vllm.LLM`.
    vllm_model_params
        Keyword arguments forwarded untouched to `vllm.LLM`. See the vLLM
        code for the full list:
        https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py

    Returns
    -------
    A `VLLM` instance wrapping the freshly created `vllm.LLM`.

    """
    # Imported lazily so that `vllm` is only required when this loader is used.
    from vllm import LLM

    return VLLM(LLM(model_name, **vllm_model_params))


def adapt_tokenizer(tokenizer: "PreTrainedTokenizerBase") -> "PreTrainedTokenizerBase":
    """Equip a `transformers` tokenizer with the attributes Outlines expects.

    The tokenizer is modified in place: it gains a `vocabulary` attribute
    (the result of `get_vocab()`), a `special_tokens` attribute (the set of
    `all_special_tokens`) and a `convert_token_to_string` method. The latter
    restores the leading space that HF's Llama tokenizers drop when a single
    token is converted, which is needed to compile FSMs for this model.

    Parameters
    ----------
    tokenizer
        The tokenizer of the model.

    Returns
    -------
    PreTrainedTokenizerBase
        The same tokenizer object, adapted.
    """
    from transformers import SPIECE_UNDERLINE

    def convert_token_to_string(token: Union[str, bytes]) -> str:
        decoded = tokenizer.convert_tokens_to_string([token])

        # A hack to handle missing spaces to HF's Llama tokenizers: tokens
        # that start with the sentencepiece underline, or the "<0x20>" byte
        # token, stand for a leading space that is lost when converting.
        starts_with_marker = type(token) is str and token.startswith(
            SPIECE_UNDERLINE
        )
        needs_leading_space = starts_with_marker or token == "<0x20>"

        return " " + decoded if needs_leading_space else decoded

    tokenizer.vocabulary = tokenizer.get_vocab()
    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
    tokenizer.convert_token_to_string = convert_token_to_string

    return tokenizer