| from functools import partial |
|
|
| from langchain.llms.base import LLM |
| from langchain.callbacks.manager import CallbackManagerForLLMRun |
| from typing import Any, Dict, Iterator, List, Optional |
| from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig |
| from exllama.tokenizer import ExLlamaTokenizer |
| from exllama.generator import ExLlamaGenerator |
| from exllama.lora import ExLlamaLora |
| import os, glob |
| from enum import Enum |
|
|
| from pydantic.v1 import root_validator |
|
|
| BROKEN_UNICODE = b'\\ufffd'.decode('unicode_escape')  # U+FFFD replacement character emitted while a multi-byte sequence is still incomplete |
|
|
| class H2OExLlamaTokenizer(ExLlamaTokenizer): |
| def __call__(self, text, *args, **kwargs): |
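| """Return a dict with the encoded input_ids, mirroring the Hugging Face tokenizer call interface.""" |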
| return dict(input_ids=self.encode(text)) |
|
|
|
|
| class H2OExLlamaGenerator(ExLlamaGenerator): |
| def is_exlama(self): |
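| """Identify this generator as exllama-based.""" |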
| return True |
|
|
|
|
| class Exllama(LLM): |
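| """LangChain LLM wrapper around a GPTQ model loaded with exllama.""" |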
| client: Any |
| model_path: Optional[str] = None |
| """The path to the GPTQ model folder.""" |
| model: Any = None |
| sanitize_bot_response: bool = False |
| prompter: Any = None |
| context: Any = '' |
| iinput: Any = '' |
|
|
| """The path to the GPTQ model folder.""" |
| exllama_cache: Optional[ExLlamaCache] = None |
| config: Optional[ExLlamaConfig] = None |
| generator: Optional[ExLlamaGenerator] = None |
| tokenizer: Optional[ExLlamaTokenizer] = None |
|
|
| |
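| ## LangChain / logging parameters |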
| logfunc = print |
| stop_sequences: Optional[List[str]] = [] |
| streaming: Optional[bool] = True |
|
|
| |
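| ## Generator sampling parameters |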
| disallowed_tokens: Optional[List[int]] = None |
| temperature: Optional[float] = None |
| top_k: Optional[int] = None |
| top_p: Optional[float] = None |
| min_p: Optional[float] = None |
| typical: Optional[float] = None |
| token_repetition_penalty_max: Optional[float] = None |
| token_repetition_penalty_sustain: Optional[int] = None |
| token_repetition_penalty_decay: Optional[int] = None |
| beams: Optional[int] = None |
| beam_length: Optional[int] = None |
|
|
| |
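| ## ExLlamaConfig overrides |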
| max_seq_len: Optional[int] = 2048 |
| compress_pos_emb: Optional[float] = 1.0 |
| set_auto_map: Optional[str] = None |
| gpu_peer_fix: Optional[bool] = None |
| alpha_value: Optional[float] = 1.0 |
|
|
| |
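| ## Low-level tuning parameters |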
| matmul_recons_thd: Optional[int] = None |
| fused_mlp_thd: Optional[int] = None |
| sdp_thd: Optional[int] = None |
| fused_attn: Optional[bool] = None |
| matmul_fused_remap: Optional[bool] = None |
| rmsnorm_no_half2: Optional[bool] = None |
| rope_no_half2: Optional[bool] = None |
| matmul_no_half2: Optional[bool] = None |
| silu_no_half2: Optional[bool] = None |
| concurrent_streams: Optional[bool] = None |
|
|
| |
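| ## LoRA parameters |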
| lora_path: Optional[str] = None |
|
|
| @staticmethod |
| def get_model_path_at(path): |
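| """Return the first model weights file (*.safetensors, *.bin or *.pt) found in path, or None.""" |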
| patterns = ["*.safetensors", "*.bin", "*.pt"] |
| model_paths = [] |
| for pattern in patterns: |
| full_pattern = os.path.join(path, pattern) |
| model_paths = glob.glob(full_pattern) |
| if model_paths: |
| break |
| if model_paths: |
| return model_paths[0] |
| else: |
| return None |
|
|
| @staticmethod |
| def configure_object(params, values, logfunc): |
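| """Return a function that copies the named parameters from values onto an object. |
| Only truthy values are applied; each applied value is logged via logfunc.""" |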
| obj_params = {k: values.get(k) for k in params} |
|
|
| def apply_to(obj): |
| for key, value in obj_params.items(): |
| if value: |
| if hasattr(obj, key): |
| setattr(obj, key, value) |
| logfunc(f"{key} {value}") |
| else: |
| raise AttributeError(f"{key} does not exist in {obj}") |
|
|
| return apply_to |
|
|
| @root_validator() |
| def validate_environment(cls, values: Dict) -> Dict: |
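| """Load (or reuse) the ExLlama model, cache, tokenizer and generator, and apply the configured parameters.""" |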
| model_param_names = [ |
| "temperature", |
| "top_k", |
| "top_p", |
| "min_p", |
| "typical", |
| "token_repetition_penalty_max", |
| "token_repetition_penalty_sustain", |
| "token_repetition_penalty_decay", |
| "beams", |
| "beam_length", |
| ] |
|
|
| config_param_names = [ |
| "max_seq_len", |
| "compress_pos_emb", |
| "gpu_peer_fix", |
| "alpha_value" |
| ] |
|
|
| tuning_parameters = [ |
| "matmul_recons_thd", |
| "fused_mlp_thd", |
| "sdp_thd", |
| "matmul_fused_remap", |
| "rmsnorm_no_half2", |
| "rope_no_half2", |
| "matmul_no_half2", |
| "silu_no_half2", |
| "concurrent_streams", |
| "fused_attn", |
| ] |
|
|
| |
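| # Silence logging unless verbose was requested |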
| verbose = values['verbose'] |
| if not verbose: |
| values['logfunc'] = lambda *args, **kwargs: None |
| logfunc = values['logfunc'] |
|
|
| if values['model'] is None: |
| model_path = values["model_path"] |
| lora_path = values["lora_path"] |
|
|
| tokenizer_path = os.path.join(model_path, "tokenizer.model") |
| model_config_path = os.path.join(model_path, "config.json") |
| model_path = Exllama.get_model_path_at(model_path) |
|
|
| config = ExLlamaConfig(model_config_path) |
| tokenizer = ExLlamaTokenizer(tokenizer_path) |
| config.model_path = model_path |
|
|
| configure_config = Exllama.configure_object(config_param_names, values, logfunc) |
| configure_config(config) |
| configure_tuning = Exllama.configure_object(tuning_parameters, values, logfunc) |
| configure_tuning(config) |
|
|
| |
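| # Optionally split the model across GPUs if an auto device map string was provided |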
| if values['set_auto_map']: |
| config.set_auto_map(values['set_auto_map']) |
| logfunc(f"set_auto_map {values['set_auto_map']}") |
|
|
| model = ExLlama(config) |
| exllama_cache = ExLlamaCache(model) |
| generator = ExLlamaGenerator(model, tokenizer, exllama_cache) |
|
|
| |
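| # Optionally load a LoRA adapter and attach it to the generator |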
| if lora_path is not None: |
| lora_config_path = os.path.join(lora_path, "adapter_config.json") |
| lora_path = Exllama.get_model_path_at(lora_path) |
| lora = ExLlamaLora(model, lora_config_path, lora_path) |
| generator.lora = lora |
| logfunc(f"Loaded LORA @ {lora_path}") |
| else: |
| generator = values['model'] |
| exllama_cache = generator.cache |
| model = generator.model |
| config = model.config |
| tokenizer = generator.tokenizer |
|
|
| |
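| # Apply sampling settings to the generator and normalize stop sequences |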
| configure_model = Exllama.configure_object(model_param_names, values, logfunc) |
| values["stop_sequences"] = [x.strip().lower() for x in values["stop_sequences"]] |
| configure_model(generator.settings) |
|
|
| setattr(generator.settings, "stop_sequences", values["stop_sequences"]) |
| logfunc(f"stop_sequences {values['stop_sequences']}") |
|
|
| disallowed = values.get("disallowed_tokens") |
| if disallowed: |
| generator.disallow_tokens(disallowed) |
| print(f"Disallowed Tokens: {generator.disallowed_tokens}") |
|
|
| values["client"] = model |
| values["generator"] = generator |
| values["config"] = config |
| values["tokenizer"] = tokenizer |
| values["exllama_cache"] = exllama_cache |
|
|
| return values |
|
|
| @property |
| def _llm_type(self) -> str: |
| """Return type of llm.""" |
| return "Exllama" |
|
|
| def get_num_tokens(self, text: str) -> int: |
| """Get the number of tokens present in the text.""" |
| return self.generator.tokenizer.num_tokens(text) |
|
|
| def get_token_ids(self, text: str) -> List[int]: |
| return self.generator.tokenizer.encode(text) |
| |
| |
|
|
| def _call( |
| self, |
| prompt: str, |
| stop: Optional[List[str]] = None, |
| run_manager: Optional[CallbackManagerForLLMRun] = None, |
| **kwargs: Any, |
| ) -> str: |
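| """Limit the prompt length, wrap it with the prompter template, and return the fully streamed response.""" |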
| assert self.tokenizer is not None |
| from h2oai_pipeline import H2OTextGenerationPipeline |
| prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) |
|
|
| |
| data_point = dict(context=self.context, instruction=prompt, input=self.iinput) |
| prompt = self.prompter.generate_prompt(data_point) |
|
|
| text = '' |
| for text1 in self.stream(prompt=prompt, stop=stop, run_manager=run_manager): |
| text = text1 |
| return text |
|
|
|
|
| class MatchStatus(Enum): |
| EXACT_MATCH = 1 |
| PARTIAL_MATCH = 0 |
| NO_MATCH = 2 |
|
|
| def match_status(self, sequence: str, banned_sequences: List[str]): |
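| """Return EXACT_MATCH if sequence equals a stop sequence, PARTIAL_MATCH if it is a prefix of one, NO_MATCH otherwise.""" |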
| sequence = sequence.strip().lower() |
| for banned_seq in banned_sequences: |
| if banned_seq == sequence: |
| return self.MatchStatus.EXACT_MATCH |
| elif banned_seq.startswith(sequence): |
| return self.MatchStatus.PARTIAL_MATCH |
| return self.MatchStatus.NO_MATCH |
|
|
| def stream( |
| self, |
| prompt: str, |
| stop: Optional[List[str]] = None, |
| run_manager: Optional[CallbackManagerForLLMRun] = None, |
| ) -> Iterator[str]: |
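| """Generate text for prompt as a stream, yielding the accumulated response and stopping on EOS, a stop sequence, or max length.""" |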
| config = self.config |
| generator = self.generator |
| beam_search = (self.beams and self.beams >= 1 and self.beam_length and self.beam_length >= 1) |
|
|
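| # Tokenize the prompt and begin generation, reusing any cached prefix |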
| ids = generator.tokenizer.encode(prompt) |
| generator.gen_begin_reuse(ids) |
|
|
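| # Choose between beam search and single-token generation |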
| if beam_search: |
| generator.begin_beam_search() |
| token_getter = generator.beam_search |
| else: |
| generator.end_beam_search() |
| token_getter = generator.gen_single_token |
|
|
| last_newline_pos = 0 |
| seq_length = len(generator.tokenizer.decode(generator.sequence_actual[0])) |
| response_start = seq_length |
| cursor_head = response_start |
|
|
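| # Callback used to push newly generated chunks to the run manager |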
| text_callback = None |
| if run_manager: |
| text_callback = partial( |
| run_manager.on_llm_new_token, verbose=self.verbose |
| ) |
| |
| |
| |
| |
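| # Accumulate the response and generate until the sequence reaches max_seq_len |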
| text = "" |
| while generator.gen_num_tokens() <= (self.max_seq_len - 4): |
| |
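| # Fetch the next token |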
| token = token_getter() |
|
|
| |
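| # On the end-of-sequence token, replace it with a newline and finish |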
| if token.item() == generator.tokenizer.eos_token_id: |
| generator.replace_last_token(generator.tokenizer.newline_token_id) |
| if beam_search: |
| generator.end_beam_search() |
| return |
|
|
| |
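| # Decode everything since the last newline; single tokens cannot be decoded reliably with sentencepiece |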
| stuff = generator.tokenizer.decode(generator.sequence_actual[0][last_newline_pos:]) |
| cursor_tail = len(stuff) |
| has_unicode_combined = cursor_tail < cursor_head |
| text_chunk = stuff[cursor_head:cursor_tail] |
| if has_unicode_combined: |
| |
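| # A multi-byte character has completed: drop the broken placeholder and take the combined character |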
| text = text[:-2] |
| text_chunk = stuff[cursor_tail-1:cursor_tail] |
| |
| cursor_head = cursor_tail |
|
|
| |
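| # Append the chunk and re-extract the bot response via the prompter |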
| text += text_chunk |
| text = self.prompter.get_response(prompt + text, prompt=prompt, |
| sanitize_bot_response=self.sanitize_bot_response) |
|
|
| if token.item() == generator.tokenizer.newline_token_id: |
| last_newline_pos = len(generator.sequence_actual[0]) |
| cursor_head = 0 |
| cursor_tail = 0 |
|
|
| |
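| # Check the accumulated text against the stop sequences |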
| status = self.match_status(text, self.stop_sequences) |
|
|
| if status == self.MatchStatus.EXACT_MATCH: |
| |
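| # Hit a stop sequence: rewind the generator past the matched text and stop |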
| rewind_length = generator.tokenizer.encode(text).shape[-1] |
| generator.gen_rewind(rewind_length) |
| |
| if beam_search: |
| generator.end_beam_search() |
| return |
| elif status == self.MatchStatus.PARTIAL_MATCH: |
| |
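| # Partial match of a stop sequence: keep buffering without yielding |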
| continue |
| elif status == self.MatchStatus.NO_MATCH: |
| if text_callback and not (text_chunk == BROKEN_UNICODE): |
| text_callback(text_chunk) |
| yield text |
|
|
| return |
|
|
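| # Minimal usage sketch (hypothetical paths; the prompter object comes from the |
| # surrounding h2oGPT code and is assumed here): |
| # |
| #   llm = Exllama(model_path="/path/to/gptq-model-folder", |
| #                 prompter=my_prompter, |
| #                 stop_sequences=["</s>"], |
| #                 temperature=0.7, |
| #                 verbose=True) |
| #   print(llm("Why is the sky blue?")) |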