File size: 15,337 Bytes
762d748 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 | import dataclasses
import inspect
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.tokenizer import Tokenizer
if TYPE_CHECKING:
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer
from outlines.processors import OutlinesLogitsProcessor
# Only the `transformers` factory function is part of this module's public API.
__all__ = [
    "transformers",
]

# Type alias for the key/value cache: a variable-length tuple whose entries are
# (key tensor, value tensor) pairs.
KVCacheType = Tuple[
    Tuple["torch.DoubleTensor", "torch.DoubleTensor"],
    ...,
]
def get_llama_tokenizer_types():
    """Return the Llama-family tokenizer classes that need work-arounds.

    Each class is imported from `transformers`. Any class whose import raises
    `ImportError` is replaced by an empty placeholder class with the same
    name, so the returned 4-tuple can always be passed to `isinstance`
    (placeholders simply never match).

    Returns
    -------
    A tuple ``(LlamaTokenizer, LlamaTokenizerFast, CodeLlamaTokenizer,
    CodeLlamaTokenizerFast)``.
    """
    try:
        from transformers.models.llama import LlamaTokenizer
    except ImportError:
        LlamaTokenizer = type("LlamaTokenizer", (), {})  # type: ignore

    try:
        from transformers.models.llama import LlamaTokenizerFast
    except ImportError:
        LlamaTokenizerFast = type("LlamaTokenizerFast", (), {})  # type: ignore

    try:
        from transformers.models.code_llama import CodeLlamaTokenizer
    except ImportError:
        CodeLlamaTokenizer = type("CodeLlamaTokenizer", (), {})  # type: ignore

    try:
        from transformers.models.code_llama import CodeLlamaTokenizerFast
    except ImportError:
        CodeLlamaTokenizerFast = type(  # type: ignore
            "CodeLlamaTokenizerFast", (), {}
        )

    return (
        LlamaTokenizer,
        LlamaTokenizerFast,
        CodeLlamaTokenizer,
        CodeLlamaTokenizerFast,
    )
class TransformerTokenizer(Tokenizer):
    """Represents a tokenizer for models in the `transformers` library.

    Wraps a Hugging Face tokenizer and exposes the attributes Outlines needs
    (EOS/pad tokens, special tokens, vocabulary) plus encode/decode helpers.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs):
        # NOTE(review): extra **kwargs are accepted but not used here.
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer.eos_token_id
        self.eos_token = tokenizer.eos_token

        if tokenizer.pad_token_id is not None:
            self.pad_token_id = tokenizer.pad_token_id
        else:
            # No pad token configured: reuse EOS for padding, and also patch
            # the wrapped tokenizer so its own padded calls use the same id.
            tokenizer.pad_token_id = tokenizer.eos_token_id
            self.pad_token_id = self.eos_token_id

        self.pad_token = tokenizer.pad_token
        self.special_tokens = set(tokenizer.all_special_tokens)
        self.vocabulary = tokenizer.get_vocab()
        self.is_llama = isinstance(tokenizer, get_llama_tokenizer_types())

    def encode(
        self, prompt: Union[str, List[str]], **kwargs
    ) -> Tuple["torch.LongTensor", "torch.LongTensor"]:
        """Tokenize `prompt` into padded PyTorch tensors.

        `padding=True` and `return_tensors="pt"` are always forced, overriding
        any caller-supplied values for those keys.

        Returns
        -------
        The ``(input_ids, attention_mask)`` pair produced by the tokenizer.
        """
        forced_kwargs = {**kwargs, "padding": True, "return_tensors": "pt"}
        encoded = self.tokenizer(prompt, **forced_kwargs)
        return encoded["input_ids"], encoded["attention_mask"]

    def decode(self, token_ids: "torch.LongTensor") -> List[str]:
        """Batch-decode `token_ids` into strings, dropping special tokens."""
        return self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    def convert_token_to_string(self, token: str) -> str:
        """Convert a single vocabulary token to the text it represents."""
        from transformers.file_utils import SPIECE_UNDERLINE

        text = self.tokenizer.convert_tokens_to_string([token])
        # Work-around: HF's Llama tokenizers drop the leading space of tokens
        # that start with the SentencePiece underline or encode a raw space.
        needs_leading_space = self.is_llama and (
            token.startswith(SPIECE_UNDERLINE) or token == "<0x20>"
        )
        return " " + text if needs_leading_space else text

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        # NOTE(review): nothing in this class sets `model_name` / `kwargs`;
        # this branch only applies if they are assigned elsewhere.
        if hasattr(self, "model_name") and hasattr(self, "kwargs"):
            return (
                other.model_name == self.model_name and other.kwargs == self.kwargs
            )
        return other.tokenizer == self.tokenizer

    def __hash__(self):
        # Hash the wrapped tokenizer using `datasets`' fingerprint Hasher.
        from datasets.fingerprint import Hasher

        return hash(Hasher.hash(self.tokenizer))

    def __getstate__(self):
        # Only the wrapped tokenizer is pickled; everything else is derived.
        return {"tokenizer": self.tokenizer}

    def __setstate__(self, state):
        # Rebuild the derived attributes by re-running the constructor.
        self.__init__(state["tokenizer"])
class Transformers:
    """Represents a `transformers` model."""

    def __init__(
        self,
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
    ):
        self.model = model
        self.tokenizer = TransformerTokenizer(tokenizer)

    def forward(
        self,
        input_ids: "torch.LongTensor",
        attention_mask: "torch.LongTensor",
        past_key_values: Optional[Tuple] = None,
    ) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]:
        """Compute a forward pass through the transformer model.

        Parameters
        ----------
        input_ids
            The input token ids. Must be one or two dimensional.
        attention_mask
            The attention mask. Must be one or two dimensional.
        past_key_values
            A tuple of tuples containing the cached key and value tensors for each
            attention head.

        Returns
        -------
        The computed logits and the new cached key and value tensors.

        Raises
        ------
        ImportError
            If `torch` is not installed.
        """
        try:
            import torch
        except ImportError as e:
            # Fail immediately with an actionable message instead of a later
            # NameError on `torch`.
            raise ImportError(
                "The `torch` library needs to be installed to use `transformers` models."
            ) from e

        assert 0 < input_ids.ndim < 3

        if past_key_values:
            # With a KV cache supplied, only the most recent token is fed to
            # the model; the cache is expected to cover earlier positions.
            input_ids = input_ids[..., -1].unsqueeze(-1)

        with torch.inference_mode():
            output = self.model(
                input_ids,
                attention_mask=attention_mask,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
                past_key_values=past_key_values,
            )

        return output.logits, output.past_key_values

    def __call__(
        self,
        input_ids: "torch.LongTensor",
        attention_mask: "torch.LongTensor",
        past_key_values: Optional[Tuple] = None,
    ) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]:
        """Return the logits at the last position and the updated KV cache."""
        logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
        next_token_logits = logits[..., -1, :]
        return next_token_logits, kv_cache

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: GenerationParameters,
        logits_processor: Optional["OutlinesLogitsProcessor"],
        sampling_parameters: SamplingParameters,
    ) -> Union[str, List[str], List[List[str]]]:
        """Generate text using `transformers`.

        Arguments
        ---------
        prompts
            A prompt or list of prompts.
        generation_parameters
            An instance of `GenerationParameters` holding the maximum number
            of new tokens, the stop sequences and the seed. These are among
            the arguments to `SequenceGeneratorAdapter`'s `__call__` method.
        logits_processor
            The logits processor to use when generating text.
        sampling_parameters
            An instance of `SamplingParameters`, a dataclass that contains
            the name of the sampler to use and related parameters as available
            in Outlines.

        Returns
        -------
        The generated text. A single prompt with a single sample gives a
        string. A batch of prompts, or several samples of one prompt, gives a
        list of strings. A batch with several samples per prompt gives a list
        of lists of strings.
        """
        generated_ids = self._generate_ids(
            prompts, generation_parameters, logits_processor, sampling_parameters
        )
        return self._decode_generation(generated_ids)

    def stream(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: GenerationParameters,
        logits_processor: Optional["OutlinesLogitsProcessor"],
        sampling_parameters: SamplingParameters,
    ) -> Iterator[Union[str, List[str]]]:
        """
        Temporary stream stand-in which implements stream() signature
        and equivalent behaviour but isn't yielded until generation completes.

        TODO: implement following completion of https://github.com/huggingface/transformers/issues/30810
        """
        generated_ids = self._generate_ids(
            prompts, generation_parameters, logits_processor, sampling_parameters
        )
        # Replay the already-finished sequences one position at a time. Each
        # slice keeps a trailing length-1 dimension so `_decode_generation`
        # returns the same nesting as `generate`.
        for i in range(generated_ids.size(-1)):
            output_group_ids = generated_ids.select(-1, i).unsqueeze(-1)
            yield self._decode_generation(output_group_ids)

    def _prepare_inputs(self, prompts: Union[str, List[str]]) -> dict:
        """Tokenize `prompts` and move the tensors to the model's device.

        A single string is wrapped in a list so the tokenizer always returns
        2D (batch-first) tensors. `attention_mask` is dropped when the model's
        `forward` does not accept it.
        """
        batch = [prompts] if isinstance(prompts, str) else prompts
        input_ids, attention_mask = self.tokenizer.encode(batch)
        inputs = {
            "input_ids": input_ids.to(self.model.device),
            "attention_mask": attention_mask.to(self.model.device),
        }
        if "attention_mask" not in inspect.signature(self.model.forward).parameters:
            del inputs["attention_mask"]
        return inputs

    def _generate_ids(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: GenerationParameters,
        logits_processor: Optional["OutlinesLogitsProcessor"],
        sampling_parameters: SamplingParameters,
    ) -> "torch.Tensor":
        """Run `model.generate` and return only the newly generated token ids.

        For a single string prompt the leading batch dimension is squeezed,
        which turns a single-sample result into a 1D tensor.
        """
        inputs = self._prepare_inputs(prompts)
        generation_kwargs = self._get_generation_kwargs(
            prompts,
            generation_parameters,
            logits_processor,
            sampling_parameters,
        )
        generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs)
        # if single str input and single sample per input, convert to a 1D output
        if isinstance(prompts, str):
            generated_ids = generated_ids.squeeze(0)
        return generated_ids

    def _get_generation_kwargs(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: GenerationParameters,
        logits_processor: Optional["OutlinesLogitsProcessor"],
        sampling_parameters: SamplingParameters,
    ) -> dict:
        """
        Convert Outlines generation parameters into `model.generate` kwargs.

        Returns a dict with `logits_processor` (a `LogitsProcessorList` or
        None), `generation_config` (a `GenerationConfig`) and `tokenizer` (the
        wrapped Hugging Face tokenizer, needed for `stop_strings`).
        """
        from transformers import GenerationConfig, LogitsProcessorList, set_seed

        max_new_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
            sampling_parameters
        )
        if max_new_tokens is None:
            # No limit requested: use a very large cap instead.
            max_new_tokens = int(2**30)

        # global seed, not desirable
        if seed is not None:
            set_seed(seed)

        if logits_processor is not None:
            logits_processor_list = LogitsProcessorList([logits_processor])
        else:
            logits_processor_list = None

        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            stop_strings=stop_at,
            num_return_sequences=(num_samples or 1),
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            do_sample=(sampler == "multinomial"),
            num_beams=(num_samples if sampler == "beam_search" else 1),
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        return dict(
            logits_processor=logits_processor_list,
            generation_config=generation_config,
            tokenizer=self.tokenizer.tokenizer,
        )

    def _generate_output_seq(
        self, prompts, inputs, generation_config, **generation_kwargs
    ):
        """Call `model.generate` and strip the prompt tokens from the output.

        When `prompts` is a list and several samples are requested per
        prompt, the result is reshaped to ``(batch, num_samples, seq)``.
        """
        input_ids = inputs["input_ids"]
        output_ids = self.model.generate(
            **inputs, generation_config=generation_config, **generation_kwargs
        )

        # encoder-decoder returns output_ids only, decoder-only returns full seq ids
        if self.model.config.is_encoder_decoder:
            generated_ids = output_ids
        else:
            generated_ids = output_ids[:, input_ids.shape[1] :]

        # if batch list inputs AND multiple samples per input, convert generated_id to 3D view
        num_samples = generation_config.num_return_sequences or 1
        if num_samples > 1 and isinstance(prompts, list):
            batch_size = input_ids.size(0)
            generated_ids = generated_ids.view(batch_size, num_samples, -1)

        return generated_ids

    def _decode_generation(self, generated_ids: "torch.Tensor"):
        """Decode 1D, 2D or 3D token-id tensors into matching nested strings.

        1D gives a single string, 2D a list of strings, and 3D a list of
        lists of strings.

        Raises
        ------
        TypeError
            If `generated_ids` has any other number of dimensions.
        """
        if len(generated_ids.shape) == 1:
            return self.tokenizer.decode([generated_ids])[0]
        elif len(generated_ids.shape) == 2:
            return self.tokenizer.decode(generated_ids)
        elif len(generated_ids.shape) == 3:
            return [self.tokenizer.decode(group) for group in generated_ids]
        else:
            raise TypeError(
                f"Generated outputs aren't 1D, 2D or 3D, but instead are {generated_ids.shape}"
            )
def transformers(
    model_name: str,
    device: Optional[str] = None,
    model_kwargs: Optional[dict] = None,
    tokenizer_kwargs: Optional[dict] = None,
    model_class=None,
    tokenizer_class=None,
):
    """Instantiate a model from the `transformers` library and its tokenizer.

    Parameters
    ----------
    model_name
        The name of the model as listed on Hugging Face's model page.
    device
        The device(s) on which the model should be loaded. This overrides
        the `device_map` entry in `model_kwargs` when provided.
    model_kwargs
        A dictionary that contains the keyword arguments to pass to the
        `from_pretrained` method when loading the model. It is copied and
        never modified.
    tokenizer_kwargs
        A dictionary that contains the keyword arguments to pass to the
        `from_pretrained` method when loading the tokenizer. It is copied and
        never modified. `padding_side` defaults to ``"left"``.
    model_class
        Class whose `from_pretrained` loads the model. Defaults to
        `transformers.AutoModelForCausalLM`.
    tokenizer_class
        Class whose `from_pretrained` loads the tokenizer. Defaults to
        `transformers.AutoTokenizer`.

    Returns
    -------
    A `Transformers` model instance.

    Raises
    ------
    ImportError
        If a default class is needed and `transformers` is not installed.
    """
    if model_class is None or tokenizer_class is None:
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError:
            raise ImportError(
                "The `transformers` library needs to be installed in order to use `transformers` models."
            )
    if model_class is None:
        model_class = AutoModelForCausalLM
    if tokenizer_class is None:
        tokenizer_class = AutoTokenizer

    # Work on copies. Writing into a shared default dict would leak settings
    # (e.g. `device_map`) into later calls, and writing into the caller's
    # dict would surprise them.
    model_kwargs = dict(model_kwargs or {})
    tokenizer_kwargs = dict(tokenizer_kwargs or {})

    if device is not None:
        model_kwargs["device_map"] = device

    model = model_class.from_pretrained(model_name, **model_kwargs)

    # Left padding keeps the prompts right-aligned so generated tokens follow
    # directly after each prompt in a batch.
    tokenizer_kwargs.setdefault("padding_side", "left")
    tokenizer = tokenizer_class.from_pretrained(model_name, **tokenizer_kwargs)

    return Transformers(model, tokenizer)
def mamba(
    model_name: str,
    device: Optional[str] = None,
    model_kwargs: Optional[dict] = None,
    tokenizer_kwargs: Optional[dict] = None,
):
    """Instantiate a Mamba model from the `transformers` library and its tokenizer.

    Thin wrapper around `transformers(...)` that loads the model with
    `transformers.MambaForCausalLM`.

    Parameters
    ----------
    model_name
        The name of the model as listed on Hugging Face's model page.
    device
        The device(s) on which the model should be loaded.
    model_kwargs
        Keyword arguments for the model's `from_pretrained` call. Not modified.
    tokenizer_kwargs
        Keyword arguments for the tokenizer's `from_pretrained` call. Not
        modified.

    Returns
    -------
    A `Transformers` model instance, as returned by `transformers(...)`.

    Raises
    ------
    ImportError
        If `MambaForCausalLM` cannot be imported from `transformers`.
    """
    try:
        from transformers import MambaForCausalLM
    except ImportError:
        raise ImportError(
            "The `mamba_ssm`, `torch` and `transformers` libraries need to be installed in order to use Mamba."
        )

    # Pass fresh copies: `transformers(...)` may add entries (e.g. `device_map`,
    # `padding_side`) to these dicts. Those must not leak into the caller's
    # dicts or into later calls through a shared default.
    return transformers(
        model_name=model_name,
        device=device,
        model_kwargs=dict(model_kwargs or {}),
        tokenizer_kwargs=dict(tokenizer_kwargs or {}),
        model_class=MambaForCausalLM,
    )
|