koichi12's picture
Add files using upload-large-folder tool
762d748 verified
raw
history blame
8.59 kB
import dataclasses
from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union
from .transformers import TransformerTokenizer
if TYPE_CHECKING:
import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer
from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.processors import OutlinesLogitsProcessor
class MLXLM:
    """Represents an `mlx_lm` model.

    Wraps a model and tokenizer obtained from `mlx_lm` and exposes the
    `generate` / `stream` interface used by Outlines.
    """

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "PreTrainedTokenizer",
    ):
        self.model = model
        # The tokenizer as returned by `mlx_lm`; used for `encode()` and for
        # its streaming `detokenizer` in `stream`.
        self.mlx_tokenizer = tokenizer
        # `_tokenizer` is the underlying Hugging Face tokenizer. The wrapper
        # provides `eos_token_id`, which `stream` uses to stop generation.
        self.tokenizer = TransformerTokenizer(tokenizer._tokenizer)

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> str:
        """Generate text and return it as a single string.

        Consumes `stream` with the same arguments and concatenates every
        segment it yields. See `stream` for the meaning of the arguments and
        for the options that raise `NotImplementedError`.
        """
        streamer = self.stream(
            prompts, generation_parameters, logits_processor, sampling_parameters
        )
        return "".join(streamer)

    @staticmethod
    def _check_supported(prompts, sampler, num_samples, top_k, seed, stop_at) -> None:
        """Raise `NotImplementedError` for generation options `mlx-lm` cannot honour."""
        if not isinstance(prompts, str):
            raise NotImplementedError(
                "The `mlx-lm` library does not support batch inference."
            )
        if sampler == "beam_search":
            raise NotImplementedError(
                "The `mlx-lm` library does not support Beam Search."
            )
        if num_samples != 1:
            raise NotImplementedError(
                "The `mlx-lm` library does not allow to take several samples."
            )
        if top_k is not None:
            raise NotImplementedError("The `mlx-lm` library does not support top_k.")
        if seed is not None:
            raise NotImplementedError("The `mlx-lm` library does not support seed.")
        if stop_at is not None:
            raise NotImplementedError("The `mlx-lm` library does not support stop_at.")

    @staticmethod
    def _resolve_temperature(temp: Optional[float]) -> float:
        """Return the sampling temperature to use, defaulting to 1.0 when unset.

        Only `None` falls back to 1.0. An explicit 0.0 is preserved so that
        `generate_step` performs greedy (argmax) decoding. A `temp or 1.0`
        expression would silently turn 0.0 into 1.0, because 0.0 is falsy.
        """
        return 1.0 if temp is None else temp

    def stream(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> Iterator[str]:
        """Generate text using `mlx_lm`, yielding it segment by segment.

        Arguments
        ---------
        prompts
            A prompt. Only a single string is supported; a list of prompts
            raises `NotImplementedError`.
        generation_parameters
            An instance of `GenerationParameters` that contains the maximum
            number of tokens, stop sequences and seed. All the arguments to
            `SequenceGeneratorAdapter`'s `__call__` method.
        logits_processor
            The logits processor to use when generating text, or `None`.
        sampling_parameters
            An instance of `SamplingParameters`, a dataclass that contains
            the name of the sampler to use and related parameters as available
            in Outlines.

        Yields
        ------
        Successive text segments taken from the tokenizer's detokenizer.
        Their concatenation is the generated text.

        Raises
        ------
        NotImplementedError
            For options `mlx-lm` does not support: batch prompts, beam search,
            more than one sample, `top_k`, `seed` or `stop_at`.
        """
        import mlx.core as mx

        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
            sampling_parameters
        )
        if max_tokens is None:
            # Effectively unbounded: generation then only stops at EOS.
            max_tokens = int(1e9)

        self._check_supported(prompts, sampler, num_samples, top_k, seed, stop_at)

        generate_kwargs = {
            "temp": temperature,
            "top_p": top_p,
            "sampler": sampler,
            "logits_processor": logits_processor,
        }

        # Adapted from
        # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
        prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts))
        detokenizer = self.mlx_tokenizer.detokenizer
        detokenizer.reset()

        # `range` must be the FIRST argument of `zip`: `zip` advances its
        # iterators left to right. With the model generator first, one extra
        # forward pass ran after the budget was spent, and its token was
        # thrown away.
        for _, (token, _prob) in zip(
            range(max_tokens),
            self.generate_step(prompt_tokens, **generate_kwargs),
        ):
            if token == self.tokenizer.eos_token_id:
                break
            detokenizer.add_token(token)
            yield detokenizer.last_segment

        # Emit whatever remains in `last_segment` once the detokenizer is finalized.
        detokenizer.finalize()
        yield detokenizer.last_segment

    def generate_step(
        self,
        prompt: "mx.array",
        temp: Optional[float],
        top_p: Optional[float],
        sampler: str,
        logits_processor: "OutlinesLogitsProcessor",
    ) -> Generator[Tuple[int, float], None, None]:
        """Yield `(token_id, probability)` pairs generated from `prompt`.

        Adapted from
        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129

        The generator never stops on its own; the caller decides when to stop.

        Args:
            prompt (mx.array): The input prompt token ids.
            temp (float, optional): The temperature for sampling. If 0, the
                argmax is used. `None` means 1.0.
            top_p (float, optional): Nucleus sampling threshold. Higher values
                mean the model considers more of the less likely words. It is
                only applied when strictly between 0 and 1.
            sampler (str): The sampler string defined by SequenceGeneratorAdapter,
                either "greedy" or "multinomial".
            logits_processor (OutlinesLogitsProcessor): Augments logits before
                sampling. It may be `None`.

        Yields:
            The sampled token id (a Python int) and its probability under the
            softmax of the processed, unscaled logits (a one-element `mx.array`).

        Raises:
            ValueError: If `sampler` is neither "greedy" nor "multinomial".
                This is raised before any model forward pass.
        """
        import mlx.core as mx
        import mlx_lm

        if sampler not in ("greedy", "multinomial"):
            raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`")

        temperature = self._resolve_temperature(temp)

        def sample(logits: "mx.array") -> Tuple["mx.array", float]:
            softmax_logits = mx.softmax(logits)
            if temperature == 0.0 or sampler == "greedy":
                token = mx.argmax(logits, axis=-1)
            elif top_p is not None and 0 < top_p < 1.0:
                token = mlx_lm.sample_utils.top_p_sampling(logits, top_p, temperature)
            else:
                # temperature != 0 here, so the division is safe.
                token = mx.random.categorical(logits * (1 / temperature))
            prob = softmax_logits[0, token]
            return token, prob

        cache = mlx_lm.models.cache.make_prompt_cache(self.model)

        # The KV cache holds the ids already processed, so each step feeds only
        # the ids the model has not seen yet: the whole prompt first, then the
        # single token sampled on the previous step.
        unprocessed_input_ids = prompt
        generated_ids: List[int] = []

        while True:
            # `[None]` adds a leading batch axis; keep the last position's logits.
            logits = self.model(unprocessed_input_ids[None], cache=cache)
            logits = logits[:, -1, :]

            if logits_processor is not None:
                # The logits processor is given a 1-D view; restore the (1, vocab) shape afterwards.
                logits_1d = logits.reshape(-1)
                logits_1d = logits_processor(generated_ids, logits_1d)
                logits = logits_1d.reshape(1, -1)

            new_token_single, prob = sample(logits)
            new_token = new_token_single.item()
            yield new_token, prob

            generated_ids.append(new_token)
            unprocessed_input_ids = new_token_single
def mlxlm(
    model_name: str,
    tokenizer_config: Optional[dict] = None,
    model_config: Optional[dict] = None,
    adapter_path: Optional[str] = None,
    lazy: bool = False,
):
    """Instantiate a model from the `mlx_lm` library and its tokenizer.

    Signature adapted from
    https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422

    Parameters
    ----------
    model_name
        The path or the Hugging Face repository to load the model from.
    tokenizer_config
        Configuration parameters specifically for the tokenizer.
        Defaults to an empty dictionary.
    model_config
        Configuration parameters specifically for the model.
        Defaults to an empty dictionary.
    adapter_path
        Path to the LoRA adapters. If provided, applies LoRA layers to the
        model. Default: ``None``.
    lazy
        If False, eval the model parameters to make sure they are loaded in
        memory before returning; otherwise they will be loaded when needed.
        Default: ``False``.

    Returns
    -------
    A `MLXLM` model instance.

    Raises
    ------
    ImportError
        If `mlx` or `mlx_lm` cannot be imported.
    RuntimeError
        If `mx.metal.is_available()` reports that Metal is unavailable.
    """
    try:
        import mlx.core as mx
        import mlx_lm
    except ImportError as err:
        raise ImportError(
            "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
        ) from err

    if not mx.metal.is_available():
        raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal)")

    # Build fresh dicts per call rather than using `{}` defaults. A mutable
    # default is a single object shared by every call of this function.
    model, tokenizer = mlx_lm.load(
        model_name,
        tokenizer_config={} if tokenizer_config is None else tokenizer_config,
        model_config={} if model_config is None else model_config,
        adapter_path=adapter_path,
        lazy=lazy,
    )
    return MLXLM(model, tokenizer)