"""
Interactive chat script for any model with automatic chat template support.

Usage: python chat_with_models.py <model_folder_name> [--assistant]
"""

import argparse
import os
import sys
import warnings

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    TextStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)

warnings.filterwarnings("ignore")

class StopSequenceCriteria(StoppingCriteria):
    """Stop generation as soon as any of the given stop sequences appears in the newly generated text."""

    def __init__(self, tokenizer, stop_sequences, prompt_length):
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences
        self.prompt_length = prompt_length
        self.triggered_stop_sequence = None

    def __call__(self, input_ids, scores, **kwargs):
        # Nothing has been generated beyond the prompt yet.
        if input_ids.shape[1] <= self.prompt_length:
            return False

        # Decode only the newly generated tokens and scan them for a stop sequence.
        new_tokens = input_ids[0][self.prompt_length:]
        new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        for stop_seq in self.stop_sequences:
            if stop_seq in new_text:
                # Record which sequence fired so the caller can strip it from the output.
                self.triggered_stop_sequence = stop_seq
                return True
        return False

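# Usage sketch for StopSequenceCriteria (mirrors how generate_response wires it up below):
#   criteria = StopSequenceCriteria(tokenizer, ["User:"], prompt_length=inputs["input_ids"].shape[1])
#   model.generate(**inputs, stopping_criteria=StoppingCriteriaList([criteria]))
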
class ModelChatter:
    """Loads a local Hugging Face checkpoint and runs a multi-turn chat against it."""

    def __init__(self, model_folder, force_assistant_template=False):
        self.model_folder = model_folder
        self.hf_path = os.path.join(model_folder, 'hf')
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.conversation_history = []
        self.force_assistant_template = force_assistant_template

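    # Assumed on-disk layout (implied by hf_path above and the checks in main()):
    #   <model_folder>/
    #       hf/    <- Hugging Face checkpoint: config, tokenizer files, and weights
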
    def load_model(self):
        """Load the tokenizer and model, assigning a chat template when needed."""
        try:
            print(f"🔄 Loading {self.model_folder}...")

            # Load the tokenizer and make sure a pad token is defined.
            self.tokenizer = AutoTokenizer.from_pretrained(self.hf_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Choose a chat template: the forced User:/Assistant: one, the model's own, or a fallback.
            if self.force_assistant_template:
                print("🔧 Forcing User:/Assistant: chat template...")
                custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}User: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAssistant: ' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAssistant: ' }}{% endif %}"""
                self.tokenizer.chat_template = custom_template
                print("✅ User:/Assistant: chat template forced")
            elif not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
                print("🔄 No chat template found, assigning custom template...")
                custom_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for msg in messages %}{% if msg.role=='user' %}{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}Instruction: {{ msg.content }}{% elif msg.role=='assistant' %}{{ '\\n\\nAnswer:' }}{{ msg.content }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\\n\\nAnswer:' }}{% endif %}"""
                self.tokenizer.chat_template = custom_template
                print("✅ Custom chat template assigned")
            else:
                print("✅ Model has existing chat template")

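            # Illustrative rendering (example only, not produced by this code): for the messages
            #   [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"},
            #    {"role": "user", "content": "How are you?"}]
            # with add_generation_prompt=True, the forced template yields
            #   "User: Hi\n\nAssistant: Hello\n\nUser: How are you?\n\nAssistant: "
            # and the fallback template yields
            #   "Instruction: Hi\n\nAnswer:Hello\n\nInstruction: How are you?\n\nAnswer:"
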
            # Load the model weights, then move them to the best available device.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.hf_path,
                device_map=None,
                torch_dtype=torch.float16,
                trust_remote_code=True
            )

            if torch.cuda.is_available():
                self.model.to("cuda:0")
                device = "cuda:0"
            elif torch.backends.mps.is_available():
                self.model.to("mps")
                device = "mps"
            else:
                self.model.to("cpu")
                device = "cpu"

            print(f" 📱 Using device: {device}")

            # Wrap the already-placed model and tokenizer in a text-generation pipeline.
            # (generate_response calls self.model.generate directly; the pipeline is kept for convenience.)
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer
            )

            print(f" ✅ {self.model_folder} loaded successfully")
            return True

        except Exception as e:
            print(f" ❌ Failed to load {self.model_folder}: {str(e)}")
            return False

    def format_chat_prompt(self, user_message):
        """Append the new user message to the history and render the conversation with the chat template."""
        self.conversation_history.append({"role": "user", "content": user_message})

        try:
            formatted_prompt = self.tokenizer.apply_chat_template(
                self.conversation_history,
                tokenize=False,
                add_generation_prompt=True
            )
            return formatted_prompt
        except Exception as e:
            print(f"❌ Error formatting chat prompt: {str(e)}")
            return None

    def generate_response(self, user_message, max_new_tokens=512):
        """Generate (and stream) a response to the user message.

        Returns an empty string on success (the reply is streamed to stdout and stored
        in the conversation history) or an error message on failure.
        """
        try:
            formatted_prompt = self.format_chat_prompt(user_message)
            if formatted_prompt is None:
                return "❌ Failed to format chat prompt"

            print("🤖 Response: ", end="", flush=True)

            # Tokenize the prompt and move it to the same device as the model.
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            if torch.cuda.is_available():
                inputs = {k: v.to("cuda:0") for k, v in inputs.items()}
            elif torch.backends.mps.is_available():
                inputs = {k: v.to("mps") for k, v in inputs.items()}

            # Stream tokens to stdout as they are generated.
            streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

            # Stop as soon as the model starts a new turn on its own.
            stop_sequences = ["Question:", "Instruction:", "Answer:", "User:"]
            prompt_length = inputs['input_ids'].shape[1]
            stopping_criteria = StopSequenceCriteria(self.tokenizer, stop_sequences, prompt_length)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    streamer=streamer,
                    eos_token_id=self.tokenizer.eos_token_id,
                    stopping_criteria=StoppingCriteriaList([stopping_criteria])
                )

            # Keep only the newly generated tokens.
            generated_text = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

            # If a stop sequence triggered, strip it (and anything after it) from the reply.
            if stopping_criteria.triggered_stop_sequence:
                stop_seq = stopping_criteria.triggered_stop_sequence
                original_text = generated_text
                if generated_text.endswith(stop_seq):
                    generated_text = generated_text[:-len(stop_seq)].rstrip()
                elif stop_seq in generated_text:
                    last_pos = generated_text.rfind(stop_seq)
                    generated_text = generated_text[:last_pos].rstrip()

                if generated_text != original_text:
                    print(f"\n🧹 Stripped stop sequence '{stop_seq}' from response")

            self.conversation_history.append({"role": "assistant", "content": generated_text})

            # The reply has already been streamed, so there is nothing left to print.
            return ""

        except Exception as e:
            return f"❌ Generation failed: {str(e)}"

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history = []
        print("🗑️ Conversation history cleared!")

    def show_conversation_history(self):
        """Display the current conversation history."""
        if not self.conversation_history:
            print("📝 No conversation history yet.")
            return

        print("\n📝 Conversation History:")
        print("=" * 50)
        for i, message in enumerate(self.conversation_history):
            role = message["role"].capitalize()
            content = message["content"]
            print(f"{role}: {content}")
            if i < len(self.conversation_history) - 1:
                print("-" * 30)
        print("=" * 50)

    def interactive_chat(self):
        """Main interactive chat loop."""
        print(f"\n💬 Chatting with {self.model_folder}")
        print("Commands:")
        print(" - Type your message to chat")
        print(" - Type 'quit' or 'exit' to end")
        print(" - Type 'help' for this message")
        print(" - Type 'reset' to clear conversation history")
        print(" - Type 'history' to show conversation history")
        print(" - Type 'clear' to clear screen")
        print("\n💡 Start chatting! (Works with any model)")

        while True:
            try:
                user_input = input("\n👤 You: ").strip()

                if not user_input:
                    continue

                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("👋 Goodbye!")
                    break

                elif user_input.lower() == 'help':
                    print(f"\n💬 Chatting with {self.model_folder}")
                    print("Commands:")
                    print(" - Type your message to chat")
                    print(" - Type 'quit' or 'exit' to end")
                    print(" - Type 'help' for this message")
                    print(" - Type 'reset' to clear conversation history")
                    print(" - Type 'history' to show conversation history")
                    print(" - Type 'clear' to clear screen")
                    print(" - Works with any model (auto-assigns chat template)")

                elif user_input.lower() == 'reset':
                    self.reset_conversation()

                elif user_input.lower() == 'history':
                    self.show_conversation_history()

                elif user_input.lower() == 'clear':
                    os.system('clear' if os.name == 'posix' else 'cls')

                else:
                    print(f"\n🤖 {self.model_folder}:")
                    response = self.generate_response(user_input)
                    # Successful replies are streamed inside generate_response; only errors come back as text.
                    if response:
                        print(response)

            except KeyboardInterrupt:
                print("\n\n👋 Goodbye!")
                break
            except Exception as e:
                print(f"❌ Error: {str(e)}")


def main():
    parser = argparse.ArgumentParser(description="Interactive chat script for any model")
    parser.add_argument("model_folder", help="Name of the model folder")
    parser.add_argument("--assistant", action="store_true",
                        help="Force User:/Assistant: chat template even if the model has its own")

    args = parser.parse_args()

    model_folder = args.model_folder
    force_assistant_template = args.assistant

    # The model folder must exist and contain an 'hf' subdirectory with the checkpoint.
    if not os.path.exists(model_folder):
        print(f"❌ Model folder '{model_folder}' not found!")
        sys.exit(1)

    hf_path = os.path.join(model_folder, 'hf')
    if not os.path.exists(hf_path):
        print(f"❌ No 'hf' subdirectory found in '{model_folder}'!")
        sys.exit(1)

    print("🚀 Model Chat Script")
    print("=" * 50)
    if force_assistant_template:
        print("🔧 Forcing User:/Assistant: chat template")
    print("=" * 50)

    chatter = ModelChatter(model_folder, force_assistant_template)

    if not chatter.load_model():
        print("❌ Failed to load model. Exiting.")
        sys.exit(1)

    print(f"✅ Model '{model_folder}' loaded successfully")

    chatter.interactive_chat()

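# Illustrative programmatic use (bypassing the CLI; "my_model" is a placeholder folder):
#   chatter = ModelChatter("my_model")
#   if chatter.load_model():
#       chatter.generate_response("Hello!")            # reply streams to stdout
#       chatter.generate_response("And a follow-up?")  # conversation_history persists between calls
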

if __name__ == "__main__":
    main()