| import logging |
| from typing import Union, List, Optional, Dict, Any, Literal |
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer |
| import transformers |
| from transformers_neuronx import MistralForSampling, GQA, NeuronConfig, QuantizationConfig |
| import time |
| import math |
| import concurrent.futures |
|
|
|
|
def padding_ceiling(n):
    """Return the smallest power of two that is >= n.

    Args:
        n (int): Requested length; values <= 0 are clamped to 1.

    Returns:
        int: The power-of-two ceiling of n.

    Uses exact integer bit arithmetic instead of the float-based
    ``2 ** ceil(log2(n))``, which can round incorrectly for very large
    integers close to a power of two.
    """
    if n <= 0:
        return 1
    # (n - 1).bit_length() is the exponent of the smallest power of two >= n:
    # for n already a power of two this yields n itself.
    return 1 << (n - 1).bit_length()
|
|
|
|
class MyStreamer(transformers.generation.streamers.BaseStreamer):
    """Streamer that records wall-clock latency between generated tokens.

    ``put`` is called by ``transformers`` generation once per produced token
    (or token batch); each call logs the elapsed time since the previous one.
    ``end`` prints a summary when generation finishes.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        """Clear recorded latencies and restart the timer."""
        self.token_latencies = []  # seconds elapsed between consecutive put() calls
        self.iter = 0              # number of put() calls observed
        self.now = time.time()     # timestamp of the previous put() (or reset)

    def put(self, tokens):
        """Record the latency of the token batch just produced.

        Args:
            tokens: The produced token ids (unused; only timing is recorded).
        """
        now = time.time()
        token_latency = now - self.now
        self.now = now
        self.iter += 1
        self.token_latencies.append(token_latency)

    def end(self):
        """Print a latency summary at the end of generation."""
        print("\n\n")
        print("First 5 token latencies:", self.token_latencies[:5])
        # Sum of all per-token latencies (i.e. total generation wall time).
        print("All token latencies:", sum(self.token_latencies))
|
|
|
|
class MistralModel:
    """
    A class for generating text using the Mistral language model compiled
    for AWS Neuron via transformers-neuronx.
    """

    def __init__(self, model_name):
        """
        Configure, compile and load the Neuron model and its tokenizer.

        Args:
            model_name (str): HuggingFace model id or local checkpoint path.
        """
        # Shard grouped-query-attention KV heads across Neuron cores and
        # quantize weights to int8, dequantizing to bf16 at compute time.
        self.neuron_config = NeuronConfig(
            group_query_attention=GQA.SHARD_OVER_HEADS,
            quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16'))

        self.model_name = model_name
        self.amp: Literal['bf16', 'fp32'] = 'bf16'
        self.batch_size = 1
        self.tp_degree = 2        # tensor-parallel degree (number of Neuron cores)
        self.n_positions = 4096   # maximum total sequence length
        # Context-length buckets to pre-compile; prompts are bucketed to one
        # of these lengths for the context-encoding pass.
        self.context_length_estimate = [2289, 4096]

        self.model = self._load_model()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Mistral-instruct chat template.
        self.prompt_template = "<s>[INST] {prompt} [/INST]"

    def _load_model(self) -> "MistralForSampling":
        """
        Load and initialize the Mistral model.

        Returns:
            MistralForSampling: The initialized, Neuron-compiled model.
        """
        model = MistralForSampling.from_pretrained(
            self.model_name,
            amp=self.amp,
            batch_size=self.batch_size,
            tp_degree=self.tp_degree,
            n_positions=self.n_positions,
            neuron_config=self.neuron_config,
            context_length_estimate=self.context_length_estimate,
        )
        model.to_neuron()  # compile and move the weights onto the Neuron cores
        return model

    @staticmethod
    def _strip_eos(text: str) -> str:
        """Remove a trailing ``</s>`` end-of-sequence token and surrounding whitespace.

        Unlike ``str.strip("</s>")`` — which strips any of the characters
        ``<``, ``/``, ``s``, ``>`` and can truncate answers ending in 's' —
        this removes the exact suffix only.
        """
        text = text.strip()
        if text.endswith('</s>'):
            text = text[:-len('</s>')]
        return text.strip()

    def generate(self, inputs: Union[str, List[int]], parameters: Optional[Dict[str, Any]] = None) -> str:
        """
        Generate text using the Mistral model.

        Args:
            inputs (Union[str, List[int]]): The input prompt or a (nested)
                list of input embeddings.
            parameters (Optional[Dict[str, Any]]): Optional sampling
                parameters; missing keys fall back to defaults.

        Returns:
            str: The generated text.

        Raises:
            ValueError: If the input type is invalid.
        """
        try:
            # BUGFIX: parameters defaults to None, so guard before .get() —
            # the original raised AttributeError when parameters was omitted.
            parameters = parameters or {}
            max_new_tokens = parameters.get("max_new_tokens", 256)
            top_k = parameters.get("top_k", 100)
            top_p = parameters.get("top_p", 0.1)
            temperature = parameters.get("temperature", 0.1)
            no_repeat_ngram_size = parameters.get("no_repeat_ngram_size", 3)
            print(
                f"parameters max_new_tokens: {max_new_tokens}, top_k: {top_k}, top_p: {top_p}, temperature: {temperature}, no_repeat_ngram_size: {no_repeat_ngram_size}")

            if isinstance(inputs, str):
                generated_text = self._generate_from_prompt(inputs, max_new_tokens, top_k, top_p, temperature,
                                                            no_repeat_ngram_size)
            elif isinstance(inputs, list):
                generated_text = self._generate_from_embeddings(inputs, max_new_tokens, top_k, top_p, temperature,
                                                                no_repeat_ngram_size)
            else:
                raise ValueError("Invalid input type. Must be str or List[int]")

            return generated_text
        except Exception as e:
            logging.error(f"Error generating text: {e}")
            raise

    def _generate_from_prompt(self, prompt: str, max_new_tokens: int, top_k: float, top_p: float, temperature: float,
                              no_repeat_ngram_size: int) -> str:
        """
        Generate text from a given prompt using the Mistral model.

        Args:
            prompt (str): The raw user prompt (wrapped in the instruct template).
            max_new_tokens (int): Maximum number of new tokens to generate.
            top_k: Sampling top-k cutoff.
            top_p: Nucleus-sampling probability mass.
            temperature: Sampling temperature.
            no_repeat_ngram_size (int): Block repeated n-grams of this size.

        Returns:
            str: The generated text with the prompt and EOS token removed.
        """
        input_prompt = self.prompt_template.format(prompt=prompt)
        encoded_input = self.tokenizer(input_prompt, return_tensors='pt')
        input_ids = encoded_input.input_ids

        with torch.inference_mode():
            # sequence_length is total (prompt + new); cap at the compiled limit.
            generated_sequence = self.model.sample(input_ids, sequence_length=min(self.n_positions,
                                                                                  input_ids.shape[1] + max_new_tokens),
                                                   start_ids=None, top_k=top_k, top_p=top_p, temperature=temperature,
                                                   no_repeat_ngram_size=no_repeat_ngram_size)
        decoded_output = [self.tokenizer.decode(tok) for tok in generated_sequence]

        # Keep only the answer after the instruct delimiter, then strip the
        # EOS marker. BUGFIX: the original .strip("</s>") stripped the
        # character set {<, /, s, >} and could eat a trailing 's'.
        generated_text = self._strip_eos(decoded_output[0].split('[/INST]')[1])
        return generated_text

    def _generate_from_embeddings(self, input_embeddings: List[int], max_new_tokens: int, top_k: float, top_p: float,
                                  temperature: float, no_repeat_ngram_size: int) -> str:
        """
        Generate text from precomputed input embeddings using the Mistral model.

        Args:
            input_embeddings (List[int]): Nested list convertible to a tensor;
                assumed shape (batch, seq_len, ...) — TODO confirm with caller.
            max_new_tokens (int): Maximum number of new tokens to generate.
            top_k: Sampling top-k cutoff.
            top_p: Nucleus-sampling probability mass.
            temperature: Sampling temperature.
            no_repeat_ngram_size (int): Block repeated n-grams of this size.

        Returns:
            str: The generated text with the EOS token removed.
        """
        s1 = time.time()
        input_embeds_tensor = torch.tensor(input_embeddings)
        input_embeds_length = input_embeds_tensor.shape[1]
        # Left-pad the sequence dimension up to the next power of two so the
        # compiled context buckets are reused; skip padding at/over the limit.
        padding_size = padding_ceiling(input_embeds_length)
        if padding_size >= self.n_positions:
            padding_size = input_embeds_length
            padded_input_embeds = input_embeds_tensor
        else:
            padding_gap = padding_size - input_embeds_length
            # NOTE(review): pads dim -2 at the front with pad_token_id; for
            # tokenizers without a pad token this is None — verify upstream.
            padded_input_embeds = F.pad(input_embeds_tensor, (0, 0, padding_gap, 0), value=self.tokenizer.pad_token_id)
        print("ms1 - input_embeds time: ", time.time() - s1)

        s2 = time.time()
        with torch.inference_mode():
            generated_sequence = self.model.sample(padded_input_embeds,
                                                   sequence_length=min(self.n_positions, padding_size + max_new_tokens),
                                                   start_ids=None, top_k=top_k, top_p=top_p, temperature=temperature,
                                                   no_repeat_ngram_size=no_repeat_ngram_size, streamer=MyStreamer())
        # Decode sequences concurrently; tokenizer.decode releases the GIL in
        # the fast (Rust) tokenizer path, so threads can overlap.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            decoded_output = list(executor.map(self.tokenizer.decode, generated_sequence))

        print("ms2 - decoded_output time: ", time.time() - s2)

        # BUGFIX: exact EOS-suffix removal instead of character-set strip.
        generated_text = self._strip_eos(decoded_output[0])
        return generated_text
|
|
|
|