{ "cells": [ { "cell_type": "raw", "id": "0", "metadata": {}, "source": [ "SPDX-License-Identifier: Apache-2.0\n", "Copyright (c) 2023, Rahul Unnikrishnan Nair \n", "NOTICE: Original was modified to support NVIDIA GPUs" ] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "---\n", "**Simple LLM Inference: Playing with Language Models**" ] }, { "cell_type": "markdown", "id": "2", "metadata": {}, "source": [ "Hello and welcome! Are you curious about how computers understand and generate human-like text? Do you want to play around with text generation without getting too technical? Then you've come to the right place.\n", "\n", "Large Language Models (LLMs) have a wide range of applications, but they can also be fun to experiment with. Here, we'll use some simple pre-trained models to explore text generation interactively.\n", "\n", "This notebook provides a hands-on experience that doesn't require deep technical knowledge. Whether you're a student, writer, educator, or just curious about AI, this guide is designed for you.\n", "\n", "Ready to try it out? 
Let's set up our environment and start exploring the world of text generation with LLMs!\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Skipping import of cpp extensions due to incompatible torch version 2.8.0+cu128 for torchao version 0.14.1 Please see https://github.com/pytorch/ao/issues/2919 for more info\n" ] } ], "source": [ "import logging\n", "import os\n", "import random\n", "import re\n", "import warnings\n", "\n", "# Pin the visible GPU before torch initializes CUDA\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"6\"\n", "\n", "# Suppress warnings for a cleaner output\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "import torch\n", "\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "from transformers import LlamaTokenizer, LlamaForCausalLM\n", "\n", "from ipywidgets import VBox, HBox, Button, Dropdown, IntSlider, FloatSlider, Text, Output, Label, Layout\n", "import ipywidgets as widgets\n", "from ipywidgets import HTML\n", "\n", "\n", "# Seed the RNGs for reproducibility (CPU always, CUDA when available)\n", "seed = 88\n", "random.seed(seed)\n", "torch.manual_seed(seed)\n", "if torch.cuda.is_available():\n", " torch.cuda.manual_seed_all(seed)\n", "\n", "def select_device(preferred_device=None):\n", " \"\"\"\n", " Selects the best available cuda device or the preferred device if specified.\n", "\n", " Args:\n", " preferred_device (str, optional): Preferred device string (e.g., \"cpu\", \"cuda\", \"cuda:0\", \"cuda:1\", etc.). 
If None, a random available cuda device will be selected or CPU if no cuda devices are available.\n", "\n", " Returns:\n", " torch.device: The selected device object.\n", " \"\"\"\n", " try:\n", " if preferred_device and preferred_device.startswith(\"cpu\"):\n", " print(\"Using CPU.\")\n", " return torch.device(\"cpu\")\n", " if preferred_device and preferred_device.startswith(\"cuda\"):\n", " if preferred_device == \"cuda\" or (\n", " \":\" in preferred_device\n", " and int(preferred_device.split(\":\")[1]) >= torch.cuda.device_count()\n", " ):\n", " preferred_device = (\n", " None # Handle as if no preferred device was specified\n", " )\n", " else:\n", " device = torch.device(preferred_device)\n", " if device.type == \"cuda\" and device.index < torch.cuda.device_count():\n", " vram_used = torch.cuda.memory_allocated(device) / (\n", " 1024**2\n", " ) # In MB\n", " print(\n", " f\"Using preferred device: {device}, VRAM used: {vram_used:.2f} MB\"\n", " )\n", " return device\n", "\n", " if torch.cuda.is_available():\n", " device_id = random.choice(\n", " range(torch.cuda.device_count())\n", " ) # Select a random available cuda device\n", " device = torch.device(f\"cuda:{device_id}\")\n", " vram_used = torch.cuda.memory_allocated(device) / (1024**2) # In MB\n", " print(f\"Selected device: {device}, VRAM used: {vram_used:.2f} MB\")\n", " return device\n", " except Exception as e:\n", " print(f\"An error occurred while selecting the device: {e}\")\n", " print(\"No cuda devices available or preferred device not found. 
Using CPU.\")\n", " return torch.device(\"cpu\")\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "6", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "MODEL_CACHE_PATH = \"~/.cache/huggingface\"\n", "class ChatBotModel:\n", " \"\"\"\n", " ChatBotModel is a class for generating responses based on text prompts using a pretrained model.\n", "\n", " Attributes:\n", " - device: The device to run the model on. Default is \"cuda\" if available, otherwise \"cpu\".\n", " - model: The loaded model for text generation.\n", " - tokenizer: The loaded tokenizer for the model.\n", " - torch_dtype: The data type to use in the model.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " model_id_or_path: str = \"openlm-research/open_llama_3b_v2\", # \"Writer/camel-5b-hf\",\n", " torch_dtype: torch.dtype = torch.bfloat16,\n", " ) -> None:\n", " \"\"\"\n", " The initializer for ChatBotModel class.\n", "\n", " Parameters:\n", " - model_id_or_path: The identifier or path of the pretrained model.\n", " - torch_dtype: The data type to use in the model. 
Default is torch.bfloat16.\n", " \"\"\"\n", " self.torch_dtype = torch_dtype\n", " self.device = select_device(\"cuda\")\n", " self.model_id_or_path = model_id_or_path\n", " local_model_id = self.model_id_or_path.replace(\"/\", \"--\")\n", " local_model_path = os.path.join(MODEL_CACHE_PATH, local_model_id)\n", "\n", " if (\n", " self.device.startswith(\"cuda\")\n", " if isinstance(self.device, str)\n", " else self.device.type == \"cuda\"\n", " ):\n", " self.autocast = torch.cuda.amp.autocast\n", " else:\n", " self.autocast = torch.cpu.amp.autocast\n", " try:\n", " if \"llama\" in model_id_or_path:\n", " self.tokenizer = LlamaTokenizer.from_pretrained(local_model_path)\n", " self.model = (\n", " LlamaForCausalLM.from_pretrained(\n", " local_model_path,\n", " low_cpu_mem_usage=True,\n", " torch_dtype=self.torch_dtype,\n", " )\n", " .to(self.device)\n", " .eval()\n", " )\n", " else:\n", " self.tokenizer = AutoTokenizer.from_pretrained(\n", " local_model_path, trust_remote_code=True\n", " )\n", " self.model = (\n", " AutoModelForCausalLM.from_pretrained(\n", " local_model_path,\n", " low_cpu_mem_usage=True,\n", " trust_remote_code=True,\n", " torch_dtype=self.torch_dtype,\n", " )\n", " .to(self.device)\n", " .eval()\n", " )\n", " except (OSError, ValueError, EnvironmentError) as e:\n", " logging.info(\n", " f\"Tokenizer / model not found locally. 
Downloading tokenizer / model for {self.model_id_or_path} to cache...: {e}\"\n", " )\n", " if \"llama\" in model_id_or_path:\n", " self.tokenizer = LlamaTokenizer.from_pretrained(self.model_id_or_path)\n", " self.model = (\n", " LlamaForCausalLM.from_pretrained(\n", " self.model_id_or_path,\n", " low_cpu_mem_usage=True,\n", " torch_dtype=self.torch_dtype,\n", " )\n", " .to(self.device)\n", " .eval()\n", " )\n", " else:\n", " self.tokenizer = AutoTokenizer.from_pretrained(\n", " self.model_id_or_path, trust_remote_code=True\n", " )\n", " self.model = (\n", " AutoModelForCausalLM.from_pretrained(\n", " self.model_id_or_path,\n", " low_cpu_mem_usage=True,\n", " trust_remote_code=True,\n", " torch_dtype=self.torch_dtype,\n", " )\n", " .to(self.device)\n", " .eval()\n", " )\n", "\n", " self.max_length = 256\n", "\n", " def prepare_input(self, previous_text, user_input):\n", " \"\"\"Prepare the input for the model, ensuring it doesn't exceed the maximum length.\n", "\n", " Note: the instruction template for camel-style models is applied by the caller,\n", " so the input is not wrapped again here.\n", " \"\"\"\n", " response_buffer = 100\n", " combined_text = previous_text + \"\\nUser: \" + user_input + \"\\nBot: \"\n", " input_ids = self.tokenizer.encode(\n", " combined_text, return_tensors=\"pt\", truncation=False\n", " )\n", " adjusted_max_length = self.max_length - response_buffer\n", " if input_ids.shape[1] > adjusted_max_length:\n", " input_ids = input_ids[:, -adjusted_max_length:]\n", " return input_ids.to(device=self.device)\n", "\n", " def gen_output(\n", " self, input_ids, temperature, top_p, top_k, num_beams, repetition_penalty\n", " ):\n", " \"\"\"\n", " Generate the output text based on the given input IDs and generation parameters.\n", "\n", " Args:\n", " input_ids (torch.Tensor): The input tensor containing token IDs.\n", " temperature (float): The temperature for controlling randomness in Boltzmann distribution.\n", " Higher values increase randomness, lower values make the generation more deterministic.\n", " top_p (float): The cumulative distribution function (CDF) threshold for Nucleus Sampling.\n", " Helps in controlling the trade-off between randomness and diversity.\n", " top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering.\n", " num_beams (int): The number of beams for beam search. 
Controls the breadth of the search.\n", " repetition_penalty (float): The penalty applied for repeating tokens.\n", "\n", " Returns:\n", " torch.Tensor: The generated output tensor.\n", " \"\"\"\n", " print(f\"Using max length: {self.max_length}\")\n", " with self.autocast(\n", " enabled=True if self.torch_dtype != torch.float32 else False,\n", " dtype=self.torch_dtype,\n", " ):\n", " with torch.no_grad():\n", " output = self.model.generate(\n", " input_ids,\n", " do_sample=True,\n", " pad_token_id=self.tokenizer.eos_token_id,\n", " max_length=self.max_length,\n", " temperature=temperature,\n", " top_p=top_p,\n", " top_k=top_k,\n", " num_beams=num_beams,\n", " repetition_penalty=repetition_penalty,\n", " )\n", " return output\n", "\n", " def warmup_model(\n", " self, temperature, top_p, top_k, num_beams, repetition_penalty\n", " ) -> None:\n", " \"\"\"\n", " Warms up the model by generating a sample response.\n", " \"\"\"\n", " sample_prompt = \"\"\"A dialog, where User interacts with a helpful Bot.\n", " AI is helpful, kind, obedient, honest, and knows its own limits.\n", " User: Hello, Bot.\n", " Bot: Hello! How can I assist you today?\n", " \"\"\"\n", " input_ids = self.tokenizer(sample_prompt, return_tensors=\"pt\").input_ids.to(\n", " device=self.device\n", " )\n", " _ = self.gen_output(\n", " input_ids,\n", " temperature=temperature,\n", " top_p=top_p,\n", " top_k=top_k,\n", " num_beams=num_beams,\n", " repetition_penalty=repetition_penalty,\n", " )\n", "\n", " def strip_response(self, generated_text):\n", " \"\"\"Return the text after the '### Response:' marker, if present.\"\"\"\n", " match = re.search(r'### Response:(.*)', generated_text, re.S)\n", " if match:\n", " return match.group(1).strip()\n", " else:\n", " return generated_text\n", "\n", " def unique_sentences(self, text: str) -> str:\n", " \"\"\"Drop duplicate sentences while preserving their original order.\"\"\"\n", " sentences = text.split(\". \")\n", " if sentences and sentences[-1] and sentences[-1][-1] != \".\":\n", " sentences = sentences[:-1]\n", " # dict.fromkeys de-duplicates while keeping insertion order (a set would scramble it)\n", " sentences = list(dict.fromkeys(sentences))\n", " return \". \".join(sentences) + \".\" if sentences else \"\"\n", "\n", " def remove_repetitions(self, text: str, user_input: str) -> str:\n", " \"\"\"\n", " Remove repetitive sentences or phrases from the generated text and avoid echoing the user's input.\n", "\n", " Args:\n", " text (str): The input text with potential repetitions.\n", " user_input (str): The user's original input to check against echoing.\n", "\n", " Returns:\n", " str: The processed text with repetitions and echoes removed.\n", " \"\"\"\n", " text = re.sub(re.escape(user_input), \"\", text, count=1).strip()\n", " text = self.unique_sentences(text)\n", " return text\n", "\n", " def extract_bot_response(self, generated_text: str) -> str:\n", " \"\"\"\n", " Extract the first response starting with \"Bot:\" from the generated text.\n", "\n", " Args:\n", " generated_text (str): The full generated text from the model.\n", "\n", " Returns:\n", " str: The extracted response starting with \"Bot:\".\n", " \"\"\"\n", " prefix = \"Bot:\"\n", " generated_text = generated_text.replace(\"\\n\", \". 
\")\n", " bot_response_start = generated_text.find(prefix)\n", " if bot_response_start != -1:\n", " response_start = bot_response_start + len(prefix)\n", " end_of_response = generated_text.find(\"\\n\", response_start)\n", " if end_of_response != -1:\n", " return generated_text[response_start:end_of_response].strip()\n", " else:\n", " return generated_text[response_start:].strip()\n", " return re.sub(r'^[^a-zA-Z0-9]+', '', generated_text)\n", "\n", " def interact(\n", " self,\n", " out: Output, # Output widget to display the conversation\n", " with_context: bool = True,\n", " temperature: float = 0.10,\n", " top_p: float = 0.95,\n", " top_k: int = 40,\n", " num_beams: int = 3,\n", " repetition_penalty: float = 1.80,\n", " ) -> None:\n", " \"\"\"\n", " Handle the chat loop where the user provides input and receives a model-generated response.\n", "\n", " Args:\n", " with_context (bool): Whether to consider previous interactions in the session. Default is True.\n", " temperature (float): The temperature for controlling randomness in Boltzmann distribution.\n", " Higher values increase randomness, lower values make the generation more deterministic.\n", " top_p (float): The cumulative distribution function (CDF) threshold for Nucleus Sampling.\n", " Helps in controlling the trade-off between randomness and diversity.\n", " top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering.\n", " num_beams (int): The number of beams for beam search. 
Controls the breadth of the search.\n", " repetition_penalty (float): The penalty applied for repeating tokens.\n", " \"\"\"\n", " previous_text = \"\"\n", " \n", " def display_user_input_widgets():\n", " default_color = \"\\033[0m\"\n", " user_color, user_icon = \"\\033[94m\", \"😀 \"\n", " bot_color, bot_icon = \"\\033[92m\", \"🤖 \"\n", " user_input_widget = Text(placeholder=\"Type your message here...\", layout=Layout(width='80%'))\n", " send_button = Button(description=\"Send\", button_style = \"primary\", layout=Layout(width='10%'))\n", " chat_spin = HTML(value = \"\")\n", " spin_style = \"\"\"\n", "
\n", " \n", " \"\"\"\n", " display(HBox([chat_spin, user_input_widget, send_button, ]))\n", " \n", " def on_send(button):\n", " nonlocal previous_text\n", " send_button.button_style = \"warning\"\n", " chat_spin.value = spin_style\n", " orig_input = \"\"\n", " user_input = user_input_widget.value\n", " with out:\n", " print(f\" {user_color}{user_icon}You: {user_input}{default_color}\")\n", " if user_input.lower() == \"exit\":\n", " return\n", " if \"camel\" in self.model_id_or_path:\n", " orig_input = user_input\n", " user_input = (\n", " \"Below is an instruction that describes a task. \"\n", " \"Write a response that appropriately completes the request.\\n\\n\"\n", " f\"### Instruction:\\n{user_input}\\n\\n### Response:\")\n", " if with_context:\n", " self.max_length = 256\n", " input_ids = self.prepare_input(previous_text, user_input)\n", " else:\n", " self.max_length = 96\n", " input_ids = self.tokenizer.encode(user_input, return_tensors=\"pt\").to(self.device)\n", " \n", " output_ids = self.gen_output(\n", " input_ids,\n", " temperature=temperature,\n", " top_p=top_p,\n", " top_k=top_k,\n", " num_beams=num_beams,\n", " repetition_penalty=repetition_penalty,\n", " )\n", " generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)\n", " generated_text = self.strip_response(generated_text)\n", " generated_text = self.extract_bot_response(generated_text)\n", " generated_text = self.remove_repetitions(generated_text, user_input)\n", " send_button.button_style = \"success\"\n", " chat_spin.value = \"\"\n", "\n", " with out:\n", " if orig_input:\n", " user_input = orig_input\n", " print(f\" {bot_color}{bot_icon}Bot: {generated_text}{default_color}\") \n", " if with_context:\n", " previous_text += \"\\nUser: \" + user_input + \"\\nBot: \" + generated_text\n", " user_input_widget.value = \"\" \n", " display_user_input_widgets()\n", " send_button.on_click(on_send)\n", " display_user_input_widgets()" ] }, { "cell_type": "code", "execution_count": 3, 
"id": "8", "metadata": {}, "outputs": [], "source": [ "model_cache = {}\n", "\n", "from ipywidgets import HTML\n", "def interact_with_llm():\n", " models = [\"Writer/camel-5b-hf\", \n", " \"openlm-research/open_llama_3b_v2\",\n", " \"Intel/neural-chat-7b-v3\", \n", " \"Intel/neural-chat-7b-v3-1\", # https://huggingface.co/Intel/neural-chat-7b-v3-1 - checkout the prompting template on the site to get better response.\n", " \"HuggingFaceH4/zephyr-7b-beta\", \n", " \"tiiuae/falcon-7b\"\n", " ]\n", " interaction_modes = [\"Interact with context\", \"Interact without context\"]\n", " model_dropdown = Dropdown(options=models, value=models[0], description=\"Model:\")\n", " interaction_mode = Dropdown(options=interaction_modes, value=interaction_modes[1], description=\"Interaction:\")\n", " temperature_slider = FloatSlider(value=0.71, min=0, max=1, step=0.01, description=\"Temperature:\")\n", " top_p_slider = FloatSlider(value=0.95, min=0, max=1, step=0.01, description=\"Top P:\")\n", " top_k_slider = IntSlider(value=40, min=0, max=100, step=1, description=\"Top K:\")\n", " num_beams_slider = IntSlider(value=3, min=1, max=10, step=1, description=\"Num Beams:\")\n", " repetition_penalty_slider = FloatSlider(value=1.80, min=0, max=2, step=0.1, description=\"Rep Penalty:\")\n", " \n", " out = Output() \n", " left_panel = VBox([model_dropdown, interaction_mode], layout=Layout(margin=\"0px 20px 10px 0px\"))\n", " right_panel = VBox([temperature_slider, top_p_slider, top_k_slider, num_beams_slider, repetition_penalty_slider],\n", " layout=Layout(margin=\"0px 0px 10px 20px\"))\n", " user_input_widgets = HBox([left_panel, right_panel], layout=Layout(margin=\"0px 50px 10px 0px\"))\n", " spinner = HTML(value=\"\")\n", " start_button = Button(description=\"Start Interaction!\", button_style=\"primary\")\n", " start_button_spinner = HBox([start_button, spinner])\n", " start_button_spinner.layout.margin = '0 auto'\n", " display(user_input_widgets)\n", " 
display(start_button_spinner)\n", " display(out)\n", " \n", " def on_start(button):\n", " start_button.button_style = \"warning\"\n", " start_button.description = \"Loading...\"\n", " spinner.value = \"\"\"\n", "
\n", " \n", " \"\"\"\n", " out.clear_output()\n", " with out:\n", " print(\"\\nSetting up the model, please wait...\")\n", " #out.clear_output()\n", " model_choice = model_dropdown.value\n", " with_context = interaction_mode.value == interaction_modes[0]\n", " temperature = temperature_slider.value\n", " top_p = top_p_slider.value\n", " top_k = top_k_slider.value\n", " num_beams = num_beams_slider.value\n", " repetition_penalty = repetition_penalty_slider.value\n", " model_key = (model_choice, \"cuda\")\n", " if model_key not in model_cache:\n", " model_cache[model_key] = ChatBotModel(model_id_or_path=model_choice)\n", " bot = model_cache[model_key]\n", " #if model_key not in model_cache:\n", " # bot.warmup_model(\n", " # temperature=temperature,\n", " # top_p=top_p,\n", " # top_k=top_k,\n", " # num_beams=num_beams,\n", " # repetition_penalty=repetition_penalty,\n", " # )\n", " \n", " with out:\n", " start_button.button_style = \"success\"\n", " start_button.description = \"Refresh\"\n", " spinner.value = \"\"\n", " print(\"Ready!\")\n", " print(\"\\nNote: This is a demonstration using pretrained models which were not fine-tuned for chat.\")\n", " print(\"If the bot doesn't respond, try clicking on refresh.\\n\")\n", " try:\n", " with out:\n", " bot.interact(\n", " with_context=with_context,\n", " out=out,\n", " temperature=temperature,\n", " top_p=top_p,\n", " top_k=top_k,\n", " num_beams=num_beams,\n", " repetition_penalty=repetition_penalty,\n", " )\n", " except Exception as e:\n", " with out:\n", " print(f\"An error occurred: {e}\")\n", "\n", " start_button.on_click(on_start)\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "10", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2918949e6bed4665ab7559773deeaaae", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(VBox(children=(Dropdown(description='Model:', 
options=('Writer/camel-5b-hf', 'openlm-research/o…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a7e2e42017224bce906b10767be0f7e5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(Button(button_style='primary', description='Start Interaction!', style=ButtonStyle()), HTML(val…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "15ca2cc52ab0496ba14c4ffcd18b36af", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "89cbe0215a314c4e9680e741a9b32d83", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/748 [00:00