Upload lstm quantized.ipynb
Browse files- lstm quantized.ipynb +410 -0
lstm quantized.ipynb
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": []
|
| 7 |
+
},
|
| 8 |
+
"kernelspec": {
|
| 9 |
+
"name": "python3",
|
| 10 |
+
"display_name": "Python 3"
|
| 11 |
+
},
|
| 12 |
+
"language_info": {
|
| 13 |
+
"name": "python"
|
| 14 |
+
}
|
| 15 |
+
},
|
| 16 |
+
"cells": [
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": 1,
|
| 20 |
+
"metadata": {
|
| 21 |
+
"colab": {
|
| 22 |
+
"base_uri": "https://localhost:8080/"
|
| 23 |
+
},
|
| 24 |
+
"id": "rjkq52idusqZ",
|
| 25 |
+
"outputId": "0248dac8-b344-464e-e759-be72db552717",
|
| 26 |
+
"collapsed": true
|
| 27 |
+
},
|
| 28 |
+
"outputs": [
|
| 29 |
+
{
|
| 30 |
+
"output_type": "stream",
|
| 31 |
+
"name": "stdout",
|
| 32 |
+
"text": [
|
| 33 |
+
"Found existing installation: torch 2.7.1+cpu\n",
|
| 34 |
+
"Uninstalling torch-2.7.1+cpu:\n",
|
| 35 |
+
" Successfully uninstalled torch-2.7.1+cpu\n",
|
| 36 |
+
"\u001b[33mWARNING: Skipping torchtext as it is not installed.\u001b[0m\u001b[33m\n",
|
| 37 |
+
"\u001b[0mLooking in indexes: https://download.pytorch.org/whl/cpu\n",
|
| 38 |
+
"Collecting torch\n",
|
| 39 |
+
" Using cached https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (27 kB)\n",
|
| 40 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
|
| 41 |
+
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
|
| 42 |
+
"Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.3)\n",
|
| 43 |
+
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
|
| 44 |
+
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
|
| 45 |
+
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.3.2)\n",
|
| 46 |
+
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy>=1.13.3->torch) (1.3.0)\n",
|
| 47 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
|
| 48 |
+
"Using cached https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl (176.0 MB)\n",
|
| 49 |
+
"Installing collected packages: torch\n",
|
| 50 |
+
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
|
| 51 |
+
"torchaudio 2.6.0+cu124 requires torch==2.6.0, but you have torch 2.7.1+cpu which is incompatible.\n",
|
| 52 |
+
"fastai 2.7.19 requires torch<2.7,>=1.10, but you have torch 2.7.1+cpu which is incompatible.\n",
|
| 53 |
+
"torchvision 0.21.0+cu124 requires torch==2.6.0, but you have torch 2.7.1+cpu which is incompatible.\u001b[0m\u001b[31m\n",
|
| 54 |
+
"\u001b[0mSuccessfully installed torch-2.7.1+cpu\n"
|
| 55 |
+
]
|
| 56 |
+
}
|
| 57 |
+
],
|
| 58 |
+
"source": [
|
| 59 |
+
"!pip uninstall torch torchtext -y\n",
|
| 60 |
+
"!pip install torch --index-url https://download.pytorch.org/whl/cpu"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"cell_type": "code",
|
| 65 |
+
"source": [
|
| 66 |
+
"!pip install torchtext --index-url https://download.pytorch.org/whl/cu118\n",
|
| 67 |
+
"!pip install 'portalocker>=2.0.0'\n",
|
| 68 |
+
"!pip install 'numpy<2'"
|
| 69 |
+
],
|
| 70 |
+
"metadata": {
|
| 71 |
+
"colab": {
|
| 72 |
+
"base_uri": "https://localhost:8080/"
|
| 73 |
+
},
|
| 74 |
+
"id": "1OtHOlKxO-UI",
|
| 75 |
+
"outputId": "7747d4cd-013d-470a-cdd7-5b346922bf8b"
|
| 76 |
+
},
|
| 77 |
+
"execution_count": 1,
|
| 78 |
+
"outputs": [
|
| 79 |
+
{
|
| 80 |
+
"output_type": "stream",
|
| 81 |
+
"name": "stdout",
|
| 82 |
+
"text": [
|
| 83 |
+
"Looking in indexes: https://download.pytorch.org/whl/cu118\n",
|
| 84 |
+
"Collecting torchtext\n",
|
| 85 |
+
" Using cached https://download.pytorch.org/whl/torchtext-0.17.0%2Bcpu-cp311-cp311-linux_x86_64.whl (2.0 MB)\n",
|
| 86 |
+
"Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from torchtext) (4.67.1)\n",
|
| 87 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from torchtext) (2.32.3)\n",
|
| 88 |
+
"Collecting torch==2.2.0 (from torchtext)\n",
|
| 89 |
+
" Using cached https://download.pytorch.org/whl/cu118/torch-2.2.0%2Bcu118-cp311-cp311-linux_x86_64.whl (811.7 MB)\n",
|
| 90 |
+
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchtext) (1.26.4)\n",
|
| 91 |
+
"Collecting torchdata==0.7.1 (from torchtext)\n",
|
| 92 |
+
" Using cached https://download.pytorch.org/whl/torchdata-0.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)\n",
|
| 93 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (3.18.0)\n",
|
| 94 |
+
"Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (4.14.0)\n",
|
| 95 |
+
"Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (1.13.3)\n",
|
| 96 |
+
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (3.5)\n",
|
| 97 |
+
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (3.1.6)\n",
|
| 98 |
+
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (2025.3.2)\n",
|
| 99 |
+
"Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch==2.2.0->torchtext)\n",
|
| 100 |
+
" Using cached https://download.pytorch.org/whl/cu118/nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (23.2 MB)\n",
|
| 101 |
+
"Collecting nvidia-cuda-runtime-cu11==11.8.89 (from torch==2.2.0->torchtext)\n",
|
| 102 |
+
" Using cached https://download.pytorch.org/whl/cu118/nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (875 kB)\n",
|
| 103 |
+
"Collecting nvidia-cuda-cupti-cu11==11.8.87 (from torch==2.2.0->torchtext)\n",
|
| 104 |
+
" Using cached https://download.pytorch.org/whl/cu118/nvidia_cuda_cupti_cu11-11.8.87-py3-none-manylinux1_x86_64.whl (13.1 MB)\n",
|
| 105 |
+
"Collecting nvidia-cudnn-cu11==8.7.0.84 (from torch==2.2.0->torchtext)\n",
|
| 106 |
+
" Using cached https://download.pytorch.org/whl/cu118/nvidia_cudnn_cu11-8.7.0.84-py3-none-manylinux1_x86_64.whl (728.5 MB)\n",
|
| 107 |
+
"Collecting nvidia-cublas-cu11==11.11.3.6 (from torch==2.2.0->torchtext)\n",
|
| 108 |
+
" Downloading https://download.pytorch.org/whl/cu118/nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux1_x86_64.whl (417.9 MB)\n",
|
| 109 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m417.9/417.9 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 110 |
+
"\u001b[?25hCollecting nvidia-cufft-cu11==10.9.0.58 (from torch==2.2.0->torchtext)\n",
|
| 111 |
+
" Downloading https://download.pytorch.org/whl/cu118/nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl (168.4 MB)\n",
|
| 112 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.4/168.4 MB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 113 |
+
"\u001b[?25hCollecting nvidia-curand-cu11==10.3.0.86 (from torch==2.2.0->torchtext)\n",
|
| 114 |
+
" Downloading https://download.pytorch.org/whl/cu118/nvidia_curand_cu11-10.3.0.86-py3-none-manylinux1_x86_64.whl (58.1 MB)\n",
|
| 115 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.1/58.1 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 116 |
+
"\u001b[?25hCollecting nvidia-cusolver-cu11==11.4.1.48 (from torch==2.2.0->torchtext)\n",
|
| 117 |
+
" Downloading https://download.pytorch.org/whl/cu118/nvidia_cusolver_cu11-11.4.1.48-py3-none-manylinux1_x86_64.whl (128.2 MB)\n",
|
| 118 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m128.2/128.2 MB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 119 |
+
"\u001b[?25hCollecting nvidia-cusparse-cu11==11.7.5.86 (from torch==2.2.0->torchtext)\n",
|
| 120 |
+
" Downloading https://download.pytorch.org/whl/cu118/nvidia_cusparse_cu11-11.7.5.86-py3-none-manylinux1_x86_64.whl (204.1 MB)\n",
|
| 121 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m204.1/204.1 MB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 122 |
+
"\u001b[?25hCollecting nvidia-nccl-cu11==2.19.3 (from torch==2.2.0->torchtext)\n",
|
| 123 |
+
" Downloading https://download.pytorch.org/whl/cu118/nvidia_nccl_cu11-2.19.3-py3-none-manylinux1_x86_64.whl (135.3 MB)\n",
|
| 124 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m135.3/135.3 MB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 125 |
+
"\u001b[?25hCollecting nvidia-nvtx-cu11==11.8.86 (from torch==2.2.0->torchtext)\n",
|
| 126 |
+
" Downloading https://download.pytorch.org/whl/cu118/nvidia_nvtx_cu11-11.8.86-py3-none-manylinux1_x86_64.whl (99 kB)\n",
|
| 127 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 128 |
+
"\u001b[?25hCollecting triton==2.2.0 (from torch==2.2.0->torchtext)\n",
|
| 129 |
+
" Downloading https://download.pytorch.org/whl/triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (167.9 MB)\n",
|
| 130 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m167.9/167.9 MB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
| 131 |
+
"\u001b[?25hRequirement already satisfied: urllib3>=1.25 in /usr/local/lib/python3.11/dist-packages (from torchdata==0.7.1->torchtext) (2.4.0)\n",
|
| 132 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->torchtext) (3.4.2)\n",
|
| 133 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->torchtext) (3.10)\n",
|
| 134 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->torchtext) (2025.6.15)\n",
|
| 135 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch==2.2.0->torchtext) (3.0.2)\n",
|
| 136 |
+
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy->torch==2.2.0->torchtext) (1.3.0)\n",
|
| 137 |
+
"Installing collected packages: triton, nvidia-nvtx-cu11, nvidia-nccl-cu11, nvidia-cusparse-cu11, nvidia-curand-cu11, nvidia-cufft-cu11, nvidia-cuda-runtime-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-cupti-cu11, nvidia-cublas-cu11, nvidia-cusolver-cu11, nvidia-cudnn-cu11, torch, torchdata, torchtext\n",
|
| 138 |
+
" Attempting uninstall: triton\n",
|
| 139 |
+
" Found existing installation: triton 3.2.0\n",
|
| 140 |
+
" Uninstalling triton-3.2.0:\n",
|
| 141 |
+
" Successfully uninstalled triton-3.2.0\n",
|
| 142 |
+
" Attempting uninstall: torch\n",
|
| 143 |
+
" Found existing installation: torch 2.7.1+cpu\n",
|
| 144 |
+
" Uninstalling torch-2.7.1+cpu:\n",
|
| 145 |
+
" Successfully uninstalled torch-2.7.1+cpu\n",
|
| 146 |
+
" Attempting uninstall: torchdata\n",
|
| 147 |
+
" Found existing installation: torchdata 0.11.0\n",
|
| 148 |
+
" Uninstalling torchdata-0.11.0:\n",
|
| 149 |
+
" Successfully uninstalled torchdata-0.11.0\n",
|
| 150 |
+
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
|
| 151 |
+
"torchaudio 2.6.0+cu124 requires torch==2.6.0, but you have torch 2.2.0+cu118 which is incompatible.\n",
|
| 152 |
+
"torchvision 0.21.0+cu124 requires torch==2.6.0, but you have torch 2.2.0+cu118 which is incompatible.\n",
|
| 153 |
+
"torchtune 0.6.1 requires torchdata==0.11.0, but you have torchdata 0.7.1 which is incompatible.\u001b[0m\u001b[31m\n",
|
| 154 |
+
"\u001b[0mSuccessfully installed nvidia-cublas-cu11-11.11.3.6 nvidia-cuda-cupti-cu11-11.8.87 nvidia-cuda-nvrtc-cu11-11.8.89 nvidia-cuda-runtime-cu11-11.8.89 nvidia-cudnn-cu11-8.7.0.84 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.3.0.86 nvidia-cusolver-cu11-11.4.1.48 nvidia-cusparse-cu11-11.7.5.86 nvidia-nccl-cu11-2.19.3 nvidia-nvtx-cu11-11.8.86 torch-2.2.0+cu118 torchdata-0.7.1 torchtext-0.17.0+cpu triton-2.2.0\n",
|
| 155 |
+
"Requirement already satisfied: portalocker>=2.0.0 in /usr/local/lib/python3.11/dist-packages (3.2.0)\n",
|
| 156 |
+
"Requirement already satisfied: numpy<2 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n"
|
| 157 |
+
]
|
| 158 |
+
}
|
| 159 |
+
]
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"cell_type": "code",
|
| 163 |
+
"execution_count": 21,
|
| 164 |
+
"metadata": {
|
| 165 |
+
"id": "0l_UBDarnXHM"
|
| 166 |
+
},
|
| 167 |
+
"outputs": [],
|
| 168 |
+
"source": [
|
| 169 |
+
"import torch\n",
|
| 170 |
+
"import torch.nn as nn\n",
|
| 171 |
+
"from huggingface_hub import hf_hub_download\n",
|
| 172 |
+
"from torchtext.datasets import IMDB\n",
|
| 173 |
+
"from torchtext.data.utils import get_tokenizer\n",
|
| 174 |
+
"from torch.nn.utils.rnn import pad_sequence # For padding\n",
|
| 175 |
+
"import torch.nn.functional as F # For softmax and multinomial sampling\n",
|
| 176 |
+
"import warnings\n",
|
| 177 |
+
"warnings.filterwarnings(\"ignore\")"
|
| 178 |
+
]
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"cell_type": "code",
|
| 182 |
+
"source": [
|
| 183 |
+
"# --- 0. Setup Global Variables and Special Tokens ---\n",
|
| 184 |
+
"# Define special tokens and their indices\n",
|
| 185 |
+
"UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3\n",
|
| 186 |
+
"special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']"
|
| 187 |
+
],
|
| 188 |
+
"metadata": {
|
| 189 |
+
"id": "Jo-Wq6FN_6zh"
|
| 190 |
+
},
|
| 191 |
+
"execution_count": 3,
|
| 192 |
+
"outputs": []
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"cell_type": "code",
|
| 196 |
+
"source": [
|
| 197 |
+
"train_iter, test_iter = IMDB(split=('train', 'test'))\n",
|
| 198 |
+
"tokenizer = get_tokenizer('basic_english')\n",
|
| 199 |
+
"\n",
|
| 200 |
+
"def yield_tokens(data_iter):\n",
|
| 201 |
+
" for _, text in data_iter:\n",
|
| 202 |
+
" yield tokenizer(text)"
|
| 203 |
+
],
|
| 204 |
+
"metadata": {
|
| 205 |
+
"id": "CsUIMjfsQ7Rn"
|
| 206 |
+
},
|
| 207 |
+
"execution_count": 18,
|
| 208 |
+
"outputs": []
|
| 209 |
+
},
|
| 210 |
+
{
|
| 211 |
+
"cell_type": "code",
|
| 212 |
+
"execution_count": 10,
|
| 213 |
+
"metadata": {
|
| 214 |
+
"id": "1DUKTDjHuo-t"
|
| 215 |
+
},
|
| 216 |
+
"outputs": [],
|
| 217 |
+
"source": [
|
| 218 |
+
"# --- 2. Model Definition (Text Generator) ---\n",
|
| 219 |
+
"class TextGenerator(nn.Module):\n",
|
| 220 |
+
" def __init__(self, vocab_size, embed_dim, hidden_dim):\n",
|
| 221 |
+
" super().__init__()\n",
|
| 222 |
+
" # Embedding layer: Converts token IDs to dense vectors\n",
|
| 223 |
+
" # `padding_idx` excludes the PAD row from gradient updates; its default zero vector is overwritten by init_weights()\n",
|
| 224 |
+
" self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)\n",
|
| 225 |
+
" # LSTM layer: Processes sequences. `batch_first=True` matches our (batch_size, seq_len) input\n",
|
| 226 |
+
" self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)\n",
|
| 227 |
+
" # Linear layer: Maps LSTM output to vocabulary size (logits for next token prediction)\n",
|
| 228 |
+
" self.fc = nn.Linear(hidden_dim, vocab_size)\n",
|
| 229 |
+
" self.init_weights()\n",
|
| 230 |
+
" self.hidden_dim = hidden_dim # Store hidden dimension for potentially initializing hidden states\n",
|
| 231 |
+
"\n",
|
| 232 |
+
" def init_weights(self):\n",
|
| 233 |
+
" # Initialize weights with a uniform distribution for better training stability\n",
|
| 234 |
+
" initrange = 0.1\n",
|
| 235 |
+
" self.embedding.weight.data.uniform_(-initrange, initrange)\n",
|
| 236 |
+
" self.fc.weight.data.uniform_(-initrange, initrange)\n",
|
| 237 |
+
" self.fc.bias.data.zero_()\n",
|
| 238 |
+
" # LSTM weights are left at PyTorch's default initialization.\n",
|
| 239 |
+
"\n",
|
| 240 |
+
" def forward(self, text, hidden=None):\n",
|
| 241 |
+
" # `text` shape: (batch_size, seq_len)\n",
|
| 242 |
+
" embedded = self.embedding(text) # Output shape: (batch_size, seq_len, embed_dim)\n",
|
| 243 |
+
" # Pass embedded sequence through LSTM.\n",
|
| 244 |
+
" # `hidden` can be passed for sequential inference (e.g., generating token by token).\n",
|
| 245 |
+
" output, hidden = self.lstm(embedded, hidden) # `output` shape: (batch_size, seq_len, hidden_dim)\n",
|
| 246 |
+
" # Apply linear layer to each time step's LSTM output\n",
|
| 247 |
+
" output = self.fc(output) # Output shape: (batch_size, seq_len, vocab_size) - logits for each token in sequence\n",
|
| 248 |
+
" return output, hidden # Return logits and the final hidden state"
|
| 249 |
+
]
|
| 250 |
+
},
|
| 251 |
+
{
|
| 252 |
+
"cell_type": "code",
|
| 253 |
+
"execution_count": 5,
|
| 254 |
+
"metadata": {
|
| 255 |
+
"id": "dDQyS6TZu17f",
|
| 256 |
+
"colab": {
|
| 257 |
+
"base_uri": "https://localhost:8080/"
|
| 258 |
+
},
|
| 259 |
+
"outputId": "cfe6963e-c2a5-49f1-dff2-1c2835dfa777"
|
| 260 |
+
},
|
| 261 |
+
"outputs": [
|
| 262 |
+
{
|
| 263 |
+
"output_type": "stream",
|
| 264 |
+
"name": "stdout",
|
| 265 |
+
"text": [
|
| 266 |
+
"\n",
|
| 267 |
+
"Example text generation:\n"
|
| 268 |
+
]
|
| 269 |
+
}
|
| 270 |
+
],
|
| 271 |
+
"source": [
|
| 272 |
+
"# --- 6. Text Generation Example ---\n",
|
| 273 |
+
"print(\"\\nExample text generation:\")\n",
|
| 274 |
+
"\n",
|
| 275 |
+
"def generate_text(model, vocab, start_text, max_len=50, temperature=0.8):\n",
|
| 276 |
+
" model.eval() # Set model to evaluation mode\n",
|
| 277 |
+
" # Convert starting text to token IDs, prepending BOS\n",
|
| 278 |
+
" input_ids = [BOS_IDX] + text_pipeline(start_text)\n",
|
| 279 |
+
" generated_ids = list(input_ids)\n",
|
| 280 |
+
"\n",
|
| 281 |
+
" # Initialize LSTM's hidden state (h_0, c_0) to None\n",
|
| 282 |
+
" hidden = None\n",
|
| 283 |
+
" model_device = next(model.parameters()).device\n",
|
| 284 |
+
"\n",
|
| 285 |
+
" with torch.no_grad():\n",
|
| 286 |
+
" for _ in range(max_len):\n",
|
| 287 |
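+
" # For generation, feed only the *last* generated token as input (NOTE: the prompt is never run through the LSTM first, so only its final token conditions step one)\n",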
|
| 288 |
+
" # This is crucial for autoregressive generation\n",
|
| 289 |
+
" current_input_tensor = torch.tensor([[generated_ids[-1]]], dtype=torch.long).to(model_device) # Shape (1, 1)\n",
|
| 290 |
+
"\n",
|
| 291 |
+
" # Pass the single token and the current hidden state to the model\n",
|
| 292 |
+
" output_logits, hidden = model(current_input_tensor, hidden)\n",
|
| 293 |
+
"\n",
|
| 294 |
+
" # Apply temperature to logits for creativity/randomness\n",
|
| 295 |
+
" # We care about the prediction for the single token in `current_input_tensor`\n",
|
| 296 |
+
" prediction_logits = output_logits[:, -1, :] / temperature\n",
|
| 297 |
+
" probabilities = F.softmax(prediction_logits, dim=-1) # Convert logits to probabilities\n",
|
| 298 |
+
"\n",
|
| 299 |
+
" # Sample the next token from the probability distribution\n",
|
| 300 |
+
" next_token_id = torch.multinomial(probabilities, num_samples=1).item()\n",
|
| 301 |
+
"\n",
|
| 302 |
+
" generated_ids.append(next_token_id) # Add the sampled token to the generated sequence\n",
|
| 303 |
+
"\n",
|
| 304 |
+
" # Stop generation if EOS token is predicted\n",
|
| 305 |
+
" if next_token_id == EOS_IDX:\n",
|
| 306 |
+
" break\n",
|
| 307 |
+
"\n",
|
| 308 |
+
" # Convert generated token IDs back to human-readable text\n",
|
| 309 |
+
" generated_text = ' '.join(vocab.lookup_tokens(generated_ids))\n",
|
| 310 |
+
" # Clean up special tokens for display\n",
|
| 311 |
+
" generated_text = generated_text.replace(vocab.lookup_token(BOS_IDX), '')\n",
|
| 312 |
+
" generated_text = generated_text.replace(vocab.lookup_token(EOS_IDX), '')\n",
|
| 313 |
+
" generated_text = generated_text.replace(vocab.lookup_token(PAD_IDX), '')\n",
|
| 314 |
+
" return ' '.join(generated_text.split()) # Remove any extra spaces caused by token replacement"
|
| 315 |
+
]
|
| 316 |
+
},
|
| 317 |
+
{
|
| 318 |
+
"cell_type": "code",
|
| 319 |
+
"source": [
|
| 320 |
+
"# quantized_model_loaded = torch.load(\"model_quant_dynamic.pth\", map_location='cpu')\n",
|
| 321 |
+
"model_path = hf_hub_download(\"wbmlr/model_quant_dynamic\", \"model_quant_dynamic.pth\")\n",
|
| 322 |
+
"quantized_model_loaded = torch.load(model_path, map_location='cpu',weights_only=False)\n",
|
| 323 |
+
"vocab_path = hf_hub_download(\"wbmlr/model_quant_dynamic\", \"vocab.pth\")\n",
|
| 324 |
+
"vocab = torch.load(vocab_path,weights_only=False)"
|
| 325 |
+
],
|
| 326 |
+
"metadata": {
|
| 327 |
+
"id": "hYIV5m3YOqvg"
|
| 328 |
+
},
|
| 329 |
+
"execution_count": 11,
|
| 330 |
+
"outputs": []
|
| 331 |
+
},
|
| 332 |
+
{
|
| 333 |
+
"cell_type": "code",
|
| 334 |
+
"source": [
|
| 335 |
+
"# Text processing pipeline: converts raw text string to a list of token IDs\n",
|
| 336 |
+
"def text_pipeline(text):\n",
|
| 337 |
+
" return vocab(tokenizer(text))"
|
| 338 |
+
],
|
| 339 |
+
"metadata": {
|
| 340 |
+
"id": "-FD9pWNHQtht"
|
| 341 |
+
},
|
| 342 |
+
"execution_count": 14,
|
| 343 |
+
"outputs": []
|
| 344 |
+
},
|
| 345 |
+
{
|
| 346 |
+
"cell_type": "code",
|
| 347 |
+
"source": [
|
| 348 |
+
"quantized_model_loaded.eval()"
|
| 349 |
+
],
|
| 350 |
+
"metadata": {
|
| 351 |
+
"colab": {
|
| 352 |
+
"base_uri": "https://localhost:8080/"
|
| 353 |
+
},
|
| 354 |
+
"id": "3qm7XSznP2a5",
|
| 355 |
+
"outputId": "4d2b3e96-bf5a-4341-82e8-90b7546e2a71"
|
| 356 |
+
},
|
| 357 |
+
"execution_count": 12,
|
| 358 |
+
"outputs": [
|
| 359 |
+
{
|
| 360 |
+
"output_type": "execute_result",
|
| 361 |
+
"data": {
|
| 362 |
+
"text/plain": [
|
| 363 |
+
"TextGenerator(\n",
|
| 364 |
+
" (embedding): Embedding(100686, 8, padding_idx=1)\n",
|
| 365 |
+
" (lstm): DynamicQuantizedLSTM(8, 16, batch_first=True)\n",
|
| 366 |
+
" (fc): DynamicQuantizedLinear(in_features=16, out_features=100686, dtype=torch.qint8, qscheme=torch.per_tensor_affine)\n",
|
| 367 |
+
")"
|
| 368 |
+
]
|
| 369 |
+
},
|
| 370 |
+
"metadata": {},
|
| 371 |
+
"execution_count": 12
|
| 372 |
+
}
|
| 373 |
+
]
|
| 374 |
+
},
|
| 375 |
+
{
|
| 376 |
+
"cell_type": "code",
|
| 377 |
+
"source": [
|
| 378 |
+
"start_prompt = \"The movie is\"\n",
|
| 379 |
+
"quant_text = generate_text(quantized_model_loaded, vocab, start_prompt)\n",
|
| 380 |
+
"print(f\"Quantized Generated: {quant_text}\")"
|
| 381 |
+
],
|
| 382 |
+
"metadata": {
|
| 383 |
+
"id": "RWg2FcQDSHF4",
|
| 384 |
+
"colab": {
|
| 385 |
+
"base_uri": "https://localhost:8080/"
|
| 386 |
+
},
|
| 387 |
+
"outputId": "1f938d07-f2a5-4ae9-e65b-fc90c82aff13"
|
| 388 |
+
},
|
| 389 |
+
"execution_count": 20,
|
| 390 |
+
"outputs": [
|
| 391 |
+
{
|
| 392 |
+
"output_type": "stream",
|
| 393 |
+
"name": "stdout",
|
| 394 |
+
"text": [
|
| 395 |
+
"Quantized Generated: the movie is devil many can it are ! a is , the the it\n"
|
| 396 |
+
]
|
| 397 |
+
}
|
| 398 |
+
]
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"cell_type": "code",
|
| 402 |
+
"execution_count": null,
|
| 403 |
+
"metadata": {
|
| 404 |
+
"id": "Ld1eYtwUGLs5"
|
| 405 |
+
},
|
| 406 |
+
"outputs": [],
|
| 407 |
+
"source": []
|
| 408 |
+
}
|
| 409 |
+
]
|
| 410 |
+
}
|