import os
from dataclasses import asdict, dataclass

from ctransformers import AutoConfig, AutoModelForCausalLM


@dataclass
class GenerationConfig:
    """Sampling and runtime parameters forwarded to the ctransformers model call."""

    temperature: float  # sampling temperature; lower values are more deterministic
    top_k: int  # keep only the k most likely tokens (0 disables the cutoff)
    top_p: float  # nucleus sampling: keep the smallest token set with mass >= top_p
    repetition_penalty: float  # values > 1.0 discourage repeated tokens
    max_new_tokens: int  # hard cap on the number of generated tokens
    seed: int  # RNG seed for reproducible sampling
    reset: bool  # reset the model's cached state before generating
    stream: bool  # yield tokens one at a time instead of returning one string
    threads: int  # CPU threads used for inference
    stop: list[str]  # strings that end generation when produced
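

# Illustrative note (placeholder values): because GenerationConfig is a
# dataclass, asdict() turns it into a plain dict that can be splatted into
# the model call as keyword arguments:
#
#   cfg = GenerationConfig(temperature=0.2, top_k=0, top_p=0.9,
#                          repetition_penalty=1.0, max_new_tokens=64, seed=42,
#                          reset=False, stream=False, threads=4,
#                          stop=["<|im_end|>"])
#   llm(prompt, **asdict(cfg))  # equivalent to passing each field by name

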
def format_prompt(system_prompt: str, user_prompt: str) -> str:
    """Assemble a ChatML-style prompt, following
    https://huggingface.co/spaces/mosaicml/mpt-30b-chat/blob/main/app.py"""
    system_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
    user_prompt = f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
    assistant_prompt = "<|im_start|>assistant\n"
    return f"{system_prompt}{user_prompt}{assistant_prompt}"
def generate(
    llm: AutoModelForCausalLM,
    generation_config: GenerationConfig,
    system_prompt: str,
    user_prompt: str,
):
    """Run model inference; returns a generator of tokens when
    generation_config.stream is True, otherwise the full completion string."""
    return llm(
        format_prompt(system_prompt, user_prompt),
        **asdict(generation_config),
    )
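

# Convenience sketch (not part of the original interface): collect a streamed
# generation into a single string, e.g. for logging or tests.
def generate_text(
    llm: AutoModelForCausalLM,
    generation_config: GenerationConfig,
    system_prompt: str,
    user_prompt: str,
) -> str:
    output = generate(llm, generation_config, system_prompt, user_prompt)
    # A streaming call yields tokens; a non-streaming call returns a string.
    return "".join(output) if generation_config.stream else output

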
if __name__ == "__main__":
    config = AutoConfig.from_pretrained("mosaicml/mpt-30b-chat", context_length=8192)
    llm = AutoModelForCausalLM.from_pretrained(
        os.path.abspath("models/mpt-30b-chat.ggmlv0.q4_1.bin"),
        model_type="mpt",
        config=config,
    )
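
    # Alternative sketch (untested; repo id and file name are assumptions):
    # ctransformers can also fetch GGML weights from the Hugging Face Hub
    # instead of a local path:
    #   llm = AutoModelForCausalLM.from_pretrained(
    #       "TheBloke/mpt-30B-chat-GGML",  # hypothetical repo id
    #       model_file="mpt-30b-chat.ggmlv0.q4_1.bin",
    #       model_type="mpt",
    #       config=config,
    #   )
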
    system_prompt = (
        "A conversation between a user and an LLM-based AI assistant named "
        "Local Assistant. Local Assistant gives helpful and honest answers."
    )

    generation_config = GenerationConfig(
        temperature=0.2,
        top_k=0,
        top_p=0.9,
        repetition_penalty=1.0,
        max_new_tokens=512,
        seed=42,
        reset=False,  # keep the conversation cache between turns
        stream=True,  # stream tokens as they are produced
        threads=max(1, (os.cpu_count() or 2) // 2),  # os.cpu_count() may return None
        stop=["<|im_end|>", "|<"],
    )
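
    # Illustrative: dataclasses.replace() derives a one-off variant without
    # mutating the shared config, e.g. a near-greedy run:
    #   from dataclasses import replace
    #   greedy_config = replace(generation_config, temperature=0.05, top_k=1)
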
    user_prefix = "[user]: "
    assistant_prefix = "[assistant]:"

    # Simple REPL: read a prompt, stream the reply token by token.
    while True:
        try:
            user_prompt = input(user_prefix)
        except (EOFError, KeyboardInterrupt):  # exit cleanly on Ctrl-D / Ctrl-C
            break
        generator = generate(llm, generation_config, system_prompt, user_prompt.strip())
        print(assistant_prefix, end=" ", flush=True)
        for token in generator:
            print(token, end="", flush=True)
        print("")