wbmlr committed on
Commit
34cf519
·
verified ·
1 Parent(s): 1dd0b4d

Upload lstm quantized.ipynb

Browse files
Files changed (1) hide show
  1. lstm quantized.ipynb +410 -0
lstm quantized.ipynb ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 1,
20
+ "metadata": {
21
+ "colab": {
22
+ "base_uri": "https://localhost:8080/"
23
+ },
24
+ "id": "rjkq52idusqZ",
25
+ "outputId": "0248dac8-b344-464e-e759-be72db552717",
26
+ "collapsed": true
27
+ },
28
+ "outputs": [
29
+ {
30
+ "output_type": "stream",
31
+ "name": "stdout",
32
+ "text": [
33
+ "Found existing installation: torch 2.7.1+cpu\n",
34
+ "Uninstalling torch-2.7.1+cpu:\n",
35
+ " Successfully uninstalled torch-2.7.1+cpu\n",
36
+ "\u001b[33mWARNING: Skipping torchtext as it is not installed.\u001b[0m\u001b[33m\n",
37
+ "\u001b[0mLooking in indexes: https://download.pytorch.org/whl/cpu\n",
38
+ "Collecting torch\n",
39
+ " Using cached https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (27 kB)\n",
40
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
41
+ "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
42
+ "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.3)\n",
43
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
44
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
45
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.3.2)\n",
46
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy>=1.13.3->torch) (1.3.0)\n",
47
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
48
+ "Using cached https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl (176.0 MB)\n",
49
+ "Installing collected packages: torch\n",
50
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
51
+ "torchaudio 2.6.0+cu124 requires torch==2.6.0, but you have torch 2.7.1+cpu which is incompatible.\n",
52
+ "fastai 2.7.19 requires torch<2.7,>=1.10, but you have torch 2.7.1+cpu which is incompatible.\n",
53
+ "torchvision 0.21.0+cu124 requires torch==2.6.0, but you have torch 2.7.1+cpu which is incompatible.\u001b[0m\u001b[31m\n",
54
+ "\u001b[0mSuccessfully installed torch-2.7.1+cpu\n"
55
+ ]
56
+ }
57
+ ],
58
+ "source": [
59
+ "!pip uninstall torch torchtext -y\n",
60
+ "!pip install torch --index-url https://download.pytorch.org/whl/cpu"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "source": [
66
+ "!pip install torchtext --index-url https://download.pytorch.org/whl/cu118\n",
67
+ "!pip install 'portalocker>=2.0.0'\n",
68
+ "!pip install 'numpy<2'"
69
+ ],
70
+ "metadata": {
71
+ "colab": {
72
+ "base_uri": "https://localhost:8080/"
73
+ },
74
+ "id": "1OtHOlKxO-UI",
75
+ "outputId": "7747d4cd-013d-470a-cdd7-5b346922bf8b"
76
+ },
77
+ "execution_count": 1,
78
+ "outputs": [
79
+ {
80
+ "output_type": "stream",
81
+ "name": "stdout",
82
+ "text": [
83
+ "Looking in indexes: https://download.pytorch.org/whl/cu118\n",
84
+ "Collecting torchtext\n",
85
+ " Using cached https://download.pytorch.org/whl/torchtext-0.17.0%2Bcpu-cp311-cp311-linux_x86_64.whl (2.0 MB)\n",
86
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from torchtext) (4.67.1)\n",
87
+ "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from torchtext) (2.32.3)\n",
88
+ "Collecting torch==2.2.0 (from torchtext)\n",
89
+ " Using cached https://download.pytorch.org/whl/cu118/torch-2.2.0%2Bcu118-cp311-cp311-linux_x86_64.whl (811.7 MB)\n",
90
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchtext) (1.26.4)\n",
91
+ "Collecting torchdata==0.7.1 (from torchtext)\n",
92
+ " Using cached https://download.pytorch.org/whl/torchdata-0.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)\n",
93
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (3.18.0)\n",
94
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (4.14.0)\n",
95
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (1.13.3)\n",
96
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (3.5)\n",
97
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (3.1.6)\n",
98
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch==2.2.0->torchtext) (2025.3.2)\n",
99
+ "Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch==2.2.0->torchtext)\n",
100
+ " Using cached https://download.pytorch.org/whl/cu118/nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (23.2 MB)\n",
101
+ "Collecting nvidia-cuda-runtime-cu11==11.8.89 (from torch==2.2.0->torchtext)\n",
102
+ " Using cached https://download.pytorch.org/whl/cu118/nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (875 kB)\n",
103
+ "Collecting nvidia-cuda-cupti-cu11==11.8.87 (from torch==2.2.0->torchtext)\n",
104
+ " Using cached https://download.pytorch.org/whl/cu118/nvidia_cuda_cupti_cu11-11.8.87-py3-none-manylinux1_x86_64.whl (13.1 MB)\n",
105
+ "Collecting nvidia-cudnn-cu11==8.7.0.84 (from torch==2.2.0->torchtext)\n",
106
+ " Using cached https://download.pytorch.org/whl/cu118/nvidia_cudnn_cu11-8.7.0.84-py3-none-manylinux1_x86_64.whl (728.5 MB)\n",
107
+ "Collecting nvidia-cublas-cu11==11.11.3.6 (from torch==2.2.0->torchtext)\n",
108
+ " Downloading https://download.pytorch.org/whl/cu118/nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux1_x86_64.whl (417.9 MB)\n",
109
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m417.9/417.9 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
110
+ "\u001b[?25hCollecting nvidia-cufft-cu11==10.9.0.58 (from torch==2.2.0->torchtext)\n",
111
+ " Downloading https://download.pytorch.org/whl/cu118/nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl (168.4 MB)\n",
112
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.4/168.4 MB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
113
+ "\u001b[?25hCollecting nvidia-curand-cu11==10.3.0.86 (from torch==2.2.0->torchtext)\n",
114
+ " Downloading https://download.pytorch.org/whl/cu118/nvidia_curand_cu11-10.3.0.86-py3-none-manylinux1_x86_64.whl (58.1 MB)\n",
115
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.1/58.1 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
116
+ "\u001b[?25hCollecting nvidia-cusolver-cu11==11.4.1.48 (from torch==2.2.0->torchtext)\n",
117
+ " Downloading https://download.pytorch.org/whl/cu118/nvidia_cusolver_cu11-11.4.1.48-py3-none-manylinux1_x86_64.whl (128.2 MB)\n",
118
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m128.2/128.2 MB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
119
+ "\u001b[?25hCollecting nvidia-cusparse-cu11==11.7.5.86 (from torch==2.2.0->torchtext)\n",
120
+ " Downloading https://download.pytorch.org/whl/cu118/nvidia_cusparse_cu11-11.7.5.86-py3-none-manylinux1_x86_64.whl (204.1 MB)\n",
121
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m204.1/204.1 MB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
122
+ "\u001b[?25hCollecting nvidia-nccl-cu11==2.19.3 (from torch==2.2.0->torchtext)\n",
123
+ " Downloading https://download.pytorch.org/whl/cu118/nvidia_nccl_cu11-2.19.3-py3-none-manylinux1_x86_64.whl (135.3 MB)\n",
124
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m135.3/135.3 MB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
125
+ "\u001b[?25hCollecting nvidia-nvtx-cu11==11.8.86 (from torch==2.2.0->torchtext)\n",
126
+ " Downloading https://download.pytorch.org/whl/cu118/nvidia_nvtx_cu11-11.8.86-py3-none-manylinux1_x86_64.whl (99 kB)\n",
127
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
128
+ "\u001b[?25hCollecting triton==2.2.0 (from torch==2.2.0->torchtext)\n",
129
+ " Downloading https://download.pytorch.org/whl/triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (167.9 MB)\n",
130
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m167.9/167.9 MB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
131
+ "\u001b[?25hRequirement already satisfied: urllib3>=1.25 in /usr/local/lib/python3.11/dist-packages (from torchdata==0.7.1->torchtext) (2.4.0)\n",
132
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->torchtext) (3.4.2)\n",
133
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->torchtext) (3.10)\n",
134
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->torchtext) (2025.6.15)\n",
135
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch==2.2.0->torchtext) (3.0.2)\n",
136
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy->torch==2.2.0->torchtext) (1.3.0)\n",
137
+ "Installing collected packages: triton, nvidia-nvtx-cu11, nvidia-nccl-cu11, nvidia-cusparse-cu11, nvidia-curand-cu11, nvidia-cufft-cu11, nvidia-cuda-runtime-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-cupti-cu11, nvidia-cublas-cu11, nvidia-cusolver-cu11, nvidia-cudnn-cu11, torch, torchdata, torchtext\n",
138
+ " Attempting uninstall: triton\n",
139
+ " Found existing installation: triton 3.2.0\n",
140
+ " Uninstalling triton-3.2.0:\n",
141
+ " Successfully uninstalled triton-3.2.0\n",
142
+ " Attempting uninstall: torch\n",
143
+ " Found existing installation: torch 2.7.1+cpu\n",
144
+ " Uninstalling torch-2.7.1+cpu:\n",
145
+ " Successfully uninstalled torch-2.7.1+cpu\n",
146
+ " Attempting uninstall: torchdata\n",
147
+ " Found existing installation: torchdata 0.11.0\n",
148
+ " Uninstalling torchdata-0.11.0:\n",
149
+ " Successfully uninstalled torchdata-0.11.0\n",
150
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
151
+ "torchaudio 2.6.0+cu124 requires torch==2.6.0, but you have torch 2.2.0+cu118 which is incompatible.\n",
152
+ "torchvision 0.21.0+cu124 requires torch==2.6.0, but you have torch 2.2.0+cu118 which is incompatible.\n",
153
+ "torchtune 0.6.1 requires torchdata==0.11.0, but you have torchdata 0.7.1 which is incompatible.\u001b[0m\u001b[31m\n",
154
+ "\u001b[0mSuccessfully installed nvidia-cublas-cu11-11.11.3.6 nvidia-cuda-cupti-cu11-11.8.87 nvidia-cuda-nvrtc-cu11-11.8.89 nvidia-cuda-runtime-cu11-11.8.89 nvidia-cudnn-cu11-8.7.0.84 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.3.0.86 nvidia-cusolver-cu11-11.4.1.48 nvidia-cusparse-cu11-11.7.5.86 nvidia-nccl-cu11-2.19.3 nvidia-nvtx-cu11-11.8.86 torch-2.2.0+cu118 torchdata-0.7.1 torchtext-0.17.0+cpu triton-2.2.0\n",
155
+ "Requirement already satisfied: portalocker>=2.0.0 in /usr/local/lib/python3.11/dist-packages (3.2.0)\n",
156
+ "Requirement already satisfied: numpy<2 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n"
157
+ ]
158
+ }
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": 21,
164
+ "metadata": {
165
+ "id": "0l_UBDarnXHM"
166
+ },
167
+ "outputs": [],
168
+ "source": [
169
+ "import torch\n",
170
+ "import torch.nn as nn\n",
171
+ "from huggingface_hub import hf_hub_download\n",
172
+ "from torchtext.datasets import IMDB\n",
173
+ "from torchtext.data.utils import get_tokenizer\n",
174
+ "from torch.nn.utils.rnn import pad_sequence # For padding\n",
175
+ "import torch.nn.functional as F # For softmax and multinomial sampling (used by generate_text)\n",
176
+ "import warnings\n",
177
+ "warnings.filterwarnings(\"ignore\")"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "source": [
183
+ "# --- 0. Setup Global Variables and Special Tokens ---\n",
184
+ "# Define special tokens and their indices\n",
185
+ "UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3\n",
186
+ "special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']"
187
+ ],
188
+ "metadata": {
189
+ "id": "Jo-Wq6FN_6zh"
190
+ },
191
+ "execution_count": 3,
192
+ "outputs": []
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "source": [
197
+ "train_iter, test_iter = IMDB(split=('train', 'test'))\n",
198
+ "tokenizer = get_tokenizer('basic_english')\n",
199
+ "\n",
200
+ "def yield_tokens(data_iter):\n",
201
+ " for _, text in data_iter:\n",
202
+ " yield tokenizer(text)"
203
+ ],
204
+ "metadata": {
205
+ "id": "CsUIMjfsQ7Rn"
206
+ },
207
+ "execution_count": 18,
208
+ "outputs": []
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 10,
213
+ "metadata": {
214
+ "id": "1DUKTDjHuo-t"
215
+ },
216
+ "outputs": [],
217
+ "source": [
218
+ "# --- 2. Model Definition (Text Generator) ---\n",
219
+ "class TextGenerator(nn.Module):\n",
220
+ " def __init__(self, vocab_size, embed_dim, hidden_dim):\n",
221
+ " super().__init__()\n",
222
+ " # Embedding layer: Converts token IDs to dense vectors\n",
223
+ " # `padding_idx` ensures that PAD tokens are ignored (zeroed out)\n",
224
+ " self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)\n",
225
+ " # LSTM layer: Processes sequences. `batch_first=True` matches our (batch_size, seq_len) input\n",
226
+ " self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)\n",
227
+ " # Linear layer: Maps LSTM output to vocabulary size (logits for next token prediction)\n",
228
+ " self.fc = nn.Linear(hidden_dim, vocab_size)\n",
229
+ " self.init_weights()\n",
230
+ " self.hidden_dim = hidden_dim # Store hidden dimension for potentially initializing hidden states\n",
231
+ "\n",
232
+ " def init_weights(self):\n",
233
+ " # Initialize weights with a uniform distribution for better training stability\n",
234
+ " initrange = 0.1\n",
235
+ " self.embedding.weight.data.uniform_(-initrange, initrange)\n",
236
+ " self.fc.weight.data.uniform_(-initrange, initrange)\n",
237
+ " self.fc.bias.data.zero_()\n",
238
+ " # LSTM weights are often initialized by PyTorch's defaults, or more sophisticated methods.\n",
239
+ "\n",
240
+ " def forward(self, text, hidden=None):\n",
241
+ " # `text` shape: (batch_size, seq_len)\n",
242
+ " embedded = self.embedding(text) # Output shape: (batch_size, seq_len, embed_dim)\n",
243
+ " # Pass embedded sequence through LSTM.\n",
244
+ " # `hidden` can be passed for sequential inference (e.g., generating token by token).\n",
245
+ " output, hidden = self.lstm(embedded, hidden) # `output` shape: (batch_size, seq_len, hidden_dim)\n",
246
+ " # Apply linear layer to each time step's LSTM output\n",
247
+ " output = self.fc(output) # Output shape: (batch_size, seq_len, vocab_size) - logits for each token in sequence\n",
248
+ " return output, hidden # Return logits and the final hidden state"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": 5,
254
+ "metadata": {
255
+ "id": "dDQyS6TZu17f",
256
+ "colab": {
257
+ "base_uri": "https://localhost:8080/"
258
+ },
259
+ "outputId": "cfe6963e-c2a5-49f1-dff2-1c2835dfa777"
260
+ },
261
+ "outputs": [
262
+ {
263
+ "output_type": "stream",
264
+ "name": "stdout",
265
+ "text": [
266
+ "\n",
267
+ "Example text generation:\n"
268
+ ]
269
+ }
270
+ ],
271
+ "source": [
272
+ "# --- 6. Text Generation Example ---\n",
273
+ "print(\"\\nExample text generation:\")\n",
274
+ "\n",
275
+ "def generate_text(model, vocab, start_text, max_len=50, temperature=0.8):\n",
276
+ " model.eval() # Set model to evaluation mode\n",
277
+ " # Convert starting text to token IDs, prepending BOS\n",
278
+ " input_ids = [BOS_IDX] + text_pipeline(start_text)\n",
279
+ " generated_ids = list(input_ids)\n",
280
+ "\n",
281
+ " # Initialize LSTM's hidden state (h_0, c_0) to None\n",
282
+ " hidden = None\n",
283
+ " model_device = next(model.parameters()).device\n",
284
+ "\n",
285
+ " with torch.no_grad():\n",
286
+ " for _ in range(max_len):\n",
287
+ " # For generation, feed only the *last* generated token as input\n",
288
+ " # This is crucial for autoregressive generation\n",
289
+ " current_input_tensor = torch.tensor([[generated_ids[-1]]], dtype=torch.long).to(model_device) # Shape (1, 1)\n",
290
+ "\n",
291
+ " # Pass the single token and the current hidden state to the model\n",
292
+ " output_logits, hidden = model(current_input_tensor, hidden)\n",
293
+ "\n",
294
+ " # Apply temperature to logits for creativity/randomness\n",
295
+ " # We care about the prediction for the single token in `current_input_tensor`\n",
296
+ " prediction_logits = output_logits[:, -1, :] / temperature\n",
297
+ " probabilities = F.softmax(prediction_logits, dim=-1) # Convert logits to probabilities\n",
298
+ "\n",
299
+ " # Sample the next token from the probability distribution\n",
300
+ " next_token_id = torch.multinomial(probabilities, num_samples=1).item()\n",
301
+ "\n",
302
+ " generated_ids.append(next_token_id) # Add the sampled token to the generated sequence\n",
303
+ "\n",
304
+ " # Stop generation if EOS token is predicted\n",
305
+ " if next_token_id == EOS_IDX:\n",
306
+ " break\n",
307
+ "\n",
308
+ " # Convert generated token IDs back to human-readable text\n",
309
+ " generated_text = ' '.join(vocab.lookup_tokens(generated_ids))\n",
310
+ " # Clean up special tokens for display\n",
311
+ " generated_text = generated_text.replace(vocab.lookup_token(BOS_IDX), '')\n",
312
+ " generated_text = generated_text.replace(vocab.lookup_token(EOS_IDX), '')\n",
313
+ " generated_text = generated_text.replace(vocab.lookup_token(PAD_IDX), '')\n",
314
+ " return ' '.join(generated_text.split()) # Remove any extra spaces caused by token replacement"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "source": [
320
+ "# quantized_model_loaded = torch.load(\"model_quant_dynamic.pth\", map_location='cpu')\n",
321
+ "model_path = hf_hub_download(\"wbmlr/model_quant_dynamic\", \"model_quant_dynamic.pth\")\n",
322
+ "quantized_model_loaded = torch.load(model_path, map_location='cpu',weights_only=False)\n",
323
+ "vocab_path = hf_hub_download(\"wbmlr/model_quant_dynamic\", \"vocab.pth\")\n",
324
+ "vocab = torch.load(vocab_path,weights_only=False)"
325
+ ],
326
+ "metadata": {
327
+ "id": "hYIV5m3YOqvg"
328
+ },
329
+ "execution_count": 11,
330
+ "outputs": []
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "source": [
335
+ "# Text processing pipeline: converts raw text string to a list of token IDs\n",
336
+ "def text_pipeline(text):\n",
337
+ " return vocab(tokenizer(text))"
338
+ ],
339
+ "metadata": {
340
+ "id": "-FD9pWNHQtht"
341
+ },
342
+ "execution_count": 14,
343
+ "outputs": []
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "source": [
348
+ "quantized_model_loaded.eval()"
349
+ ],
350
+ "metadata": {
351
+ "colab": {
352
+ "base_uri": "https://localhost:8080/"
353
+ },
354
+ "id": "3qm7XSznP2a5",
355
+ "outputId": "4d2b3e96-bf5a-4341-82e8-90b7546e2a71"
356
+ },
357
+ "execution_count": 12,
358
+ "outputs": [
359
+ {
360
+ "output_type": "execute_result",
361
+ "data": {
362
+ "text/plain": [
363
+ "TextGenerator(\n",
364
+ " (embedding): Embedding(100686, 8, padding_idx=1)\n",
365
+ " (lstm): DynamicQuantizedLSTM(8, 16, batch_first=True)\n",
366
+ " (fc): DynamicQuantizedLinear(in_features=16, out_features=100686, dtype=torch.qint8, qscheme=torch.per_tensor_affine)\n",
367
+ ")"
368
+ ]
369
+ },
370
+ "metadata": {},
371
+ "execution_count": 12
372
+ }
373
+ ]
374
+ },
375
+ {
376
+ "cell_type": "code",
377
+ "source": [
378
+ "start_prompt = \"The movie is\"\n",
379
+ "quant_text = generate_text(quantized_model_loaded, vocab, start_prompt)\n",
380
+ "print(f\"Quantized Generated: {quant_text}\")"
381
+ ],
382
+ "metadata": {
383
+ "id": "RWg2FcQDSHF4",
384
+ "colab": {
385
+ "base_uri": "https://localhost:8080/"
386
+ },
387
+ "outputId": "1f938d07-f2a5-4ae9-e65b-fc90c82aff13"
388
+ },
389
+ "execution_count": 20,
390
+ "outputs": [
391
+ {
392
+ "output_type": "stream",
393
+ "name": "stdout",
394
+ "text": [
395
+ "Quantized Generated: the movie is devil many can it are ! a is , the the it\n"
396
+ ]
397
+ }
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "metadata": {
404
+ "id": "Ld1eYtwUGLs5"
405
+ },
406
+ "outputs": [],
407
+ "source": []
408
+ }
409
+ ]
410
+ }