{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tXtGenIntro01"
   },
   "source": [
    "# Text generation with a dynamically quantized LSTM\n",
    "\n",
    "Downloads a dynamically quantized `TextGenerator` (embedding -> LSTM -> linear) and its torchtext vocabulary from the Hugging Face Hub repo `wbmlr/model_quant_dynamic`, then samples a continuation of a text prompt on CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rjkq52idusqZ"
   },
   "outputs": [],
   "source": [
    "# torchtext 0.17.0 requires torch==2.2.0, so install the pinned pair directly instead of\n",
    "# first installing a newer CPU-only torch that the torchtext install would replace anyway.\n",
    "%pip uninstall -y torch torchtext\n",
    "%pip install torch==2.2.0 torchtext==0.17.0 --index-url https://download.pytorch.org/whl/cu118"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1OtHOlKxO-UI"
   },
   "outputs": [],
   "source": [
    "# portalocker is needed by the torchtext IMDB datapipes; torch 2.2 wheels predate NumPy 2, so keep NumPy 1.x.\n",
    "%pip install 'portalocker>=2.0.0' 'numpy<2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0l_UBDarnXHM"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F  # softmax for temperature sampling in generate_text\n",
    "from huggingface_hub import hf_hub_download\n",
    "from torchtext.datasets import IMDB\n",
    "from torchtext.data.utils import get_tokenizer\n",
    "from torch.nn.utils.rnn import pad_sequence # For padding\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Jo-Wq6FN_6zh"
   },
   "outputs": [],
   "source": [
    "# --- 0. Setup Global Variables and Special Tokens ---\n",
    "# Define special tokens and their indices; the indices must agree with the vocab loaded below\n",
    "# (the loaded model's embedding uses padding_idx=1 == PAD_IDX).\n",
    "UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3\n",
    "special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CsUIMjfsQ7Rn"
   },
   "outputs": [],
   "source": [
    "# The IMDB iterators and `yield_tokens` are not used below (the vocab is downloaded from the Hub);\n",
    "# `tokenizer` is used by `text_pipeline`.\n",
    "train_iter, test_iter = IMDB(split=('train', 'test'))\n",
    "tokenizer = get_tokenizer('basic_english')\n",
    "\n",
    "def yield_tokens(data_iter):\n",
    "    for _, text in data_iter:\n",
    "        yield tokenizer(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1DUKTDjHuo-t"
   },
   "outputs": [],
   "source": [
    "# --- 2. Model Definition (Text Generator) ---\n",
    "# This class must be defined before the pickled model is loaded below.\n",
    "class TextGenerator(nn.Module):\n",
    "    \"\"\"Embedding -> single-layer LSTM -> linear head producing next-token logits.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embed_dim, hidden_dim):\n",
    "        super().__init__()\n",
    "        # Embedding layer: Converts token IDs to dense vectors\n",
    "        # `padding_idx` ensures that PAD tokens are ignored (zeroed out)\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)\n",
    "        # LSTM layer: Processes sequences. `batch_first=True` matches our (batch_size, seq_len) input\n",
    "        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)\n",
    "        # Linear layer: Maps LSTM output to vocabulary size (logits for next token prediction)\n",
    "        self.fc = nn.Linear(hidden_dim, vocab_size)\n",
    "        self.init_weights()\n",
    "        self.hidden_dim = hidden_dim # Store hidden dimension for potentially initializing hidden states\n",
    "\n",
    "    def init_weights(self):\n",
    "        # Initialize weights with a uniform distribution for better training stability\n",
    "        initrange = 0.1\n",
    "        self.embedding.weight.data.uniform_(-initrange, initrange)\n",
    "        # uniform_ overwrote the padding row that nn.Embedding zeroes; restore it.\n",
    "        self.embedding.weight.data[PAD_IDX].zero_()\n",
    "        self.fc.weight.data.uniform_(-initrange, initrange)\n",
    "        self.fc.bias.data.zero_()\n",
    "        # LSTM weights keep PyTorch's default initialization.\n",
    "\n",
    "    def forward(self, text, hidden=None):\n",
    "        # `text` shape: (batch_size, seq_len)\n",
    "        embedded = self.embedding(text) # Output shape: (batch_size, seq_len, embed_dim)\n",
    "        # Pass embedded sequence through LSTM.\n",
    "        # `hidden` can be passed for sequential inference (e.g., generating token by token).\n",
    "        output, hidden = self.lstm(embedded, hidden) # `output` shape: (batch_size, seq_len, hidden_dim)\n",
    "        # Apply linear layer to each time step's LSTM output\n",
    "        output = self.fc(output) # Output shape: (batch_size, seq_len, vocab_size) - logits for each token in sequence\n",
    "        return output, hidden # Return logits and the final hidden state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dDQyS6TZu17f"
   },
   "outputs": [],
   "source": [
    "# --- 6. Text Generation ---\n",
    "def generate_text(model, vocab, start_text, max_len=50, temperature=0.8):\n",
    "    \"\"\"Sample a continuation of `start_text` from an autoregressive LSTM language model.\n",
    "\n",
    "    The whole prompt (BOS + prompt tokens) is fed through the model first so the LSTM\n",
    "    state reflects the full context; after that, one sampled token is fed per step.\n",
    "\n",
    "    Args:\n",
    "        model: module called as `model(token_ids, hidden) -> (logits, hidden)` with\n",
    "            logits of shape (batch, seq_len, vocab_size), e.g. `TextGenerator`.\n",
    "        vocab: torchtext vocab used to map ids back to tokens (`lookup_tokens`).\n",
    "        start_text: prompt string, converted to ids with `text_pipeline`.\n",
    "        max_len: maximum number of new tokens to sample.\n",
    "        temperature: softmax temperature (> 0); lower values sample more greedily.\n",
    "\n",
    "    Returns:\n",
    "        The prompt plus the sampled continuation as a space-joined string, with\n",
    "        BOS/EOS/PAD tokens removed.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `temperature` is not positive.\n",
    "    \"\"\"\n",
    "    if temperature <= 0:\n",
    "        raise ValueError(f\"temperature must be > 0, got {temperature}\")\n",
    "\n",
    "    model.eval()  # Set model to evaluation mode\n",
    "    generated_ids = [BOS_IDX] + text_pipeline(start_text)\n",
    "    model_device = next(model.parameters()).device\n",
    "\n",
    "    # LSTM state (h, c); None makes the LSTM start from zeros.\n",
    "    hidden = None\n",
    "    # First step feeds the whole prompt, shape (1, prompt_len). Feeding only the last\n",
    "    # prompt token would silently discard the rest of the prompt.\n",
    "    step_input = torch.tensor([generated_ids], dtype=torch.long, device=model_device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for _ in range(max_len):\n",
    "            output_logits, hidden = model(step_input, hidden)\n",
    "            # Logits at the last position predict the next token; temperature rescales them.\n",
    "            prediction_logits = output_logits[:, -1, :] / temperature\n",
    "            probabilities = F.softmax(prediction_logits, dim=-1)\n",
    "            next_token_id = torch.multinomial(probabilities, num_samples=1).item()\n",
    "            generated_ids.append(next_token_id)\n",
    "\n",
    "            # Stop generation if EOS token is predicted\n",
    "            if next_token_id == EOS_IDX:\n",
    "                break\n",
    "            # Later steps feed only the new token; earlier context is carried in `hidden`.\n",
    "            step_input = torch.tensor([[next_token_id]], dtype=torch.long, device=model_device)\n",
    "\n",
    "    # Drop special tokens by id rather than by string replacement on the joined text,\n",
    "    # which could also remove matching substrings of ordinary tokens.\n",
    "    special_ids = {BOS_IDX, EOS_IDX, PAD_IDX}\n",
    "    kept_ids = [token_id for token_id in generated_ids if token_id not in special_ids]\n",
    "    generated_text = ' '.join(vocab.lookup_tokens(kept_ids))\n",
    "    return ' '.join(generated_text.split())  # collapse any stray whitespace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hYIV5m3YOqvg"
   },
   "outputs": [],
   "source": [
    "# Download the pickled quantized model and its vocab from the Hugging Face Hub.\n",
    "# NOTE: weights_only=False unpickles arbitrary objects, which can execute code;\n",
    "# only load files from a repository you trust.\n",
    "model_path = hf_hub_download(\"wbmlr/model_quant_dynamic\", \"model_quant_dynamic.pth\")\n",
    "quantized_model_loaded = torch.load(model_path, map_location='cpu', weights_only=False)\n",
    "vocab_path = hf_hub_download(\"wbmlr/model_quant_dynamic\", \"vocab.pth\")\n",
    "vocab = torch.load(vocab_path, weights_only=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-FD9pWNHQtht"
   },
   "outputs": [],
   "source": [
    "# Text processing pipeline: converts raw text string to a list of token IDs\n",
    "def text_pipeline(text):\n",
    "    return vocab(tokenizer(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3qm7XSznP2a5"
   },
   "outputs": [],
   "source": [
    "quantized_model_loaded.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RWg2FcQDSHF4"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(42)  # make the sampled continuation reproducible\n",
    "start_prompt = \"The movie is\"\n",
    "quant_text = generate_text(quantized_model_loaded, vocab, start_prompt)\n",
    "print(f\"Quantized Generated: {quant_text}\")"
   ]
  }
 ]
}