abdurafay19 committed on
Commit
038d8bd
·
verified ·
1 Parent(s): 13d7827

Upload MNIST_Training.ipynb

Browse files
Files changed (1) hide show
  1. MNIST_Training.ipynb +466 -0
MNIST_Training.ipynb ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "id": "o_xNUk10GCIa"
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "import torch\n",
12
+ "import torch.nn as nn\n",
13
+ "import torch.optim as optim\n",
14
+ "from torch.utils.data import DataLoader\n",
15
+ "from torchvision import datasets, transforms\n",
16
+ "import numpy as np"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "metadata": {
23
+ "id": "CEfUc-G5GmJm"
24
+ },
25
+ "outputs": [],
26
+ "source": [
27
+ "CONFIG = {\n",
28
+ " \"batch_size\": 64,\n",
29
+ " \"epochs\": 50,\n",
30
+ " \"lr\": 0.003,\n",
31
+ " \"weight_decay\": 0.0001,\n",
32
+ " \"label_smoothing\": 0.1,\n",
33
+ " \"num_workers\": 2,\n",
34
+ " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
35
+ " \"seed\": 23,\n",
36
+ "}"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 3,
42
+ "metadata": {
43
+ "id": "SaIhfZfCG0Wn",
44
+ "colab": {
45
+ "base_uri": "https://localhost:8080/"
46
+ },
47
+ "outputId": "b59d824d-81be-4463-8fd7-0910b78acee2"
48
+ },
49
+ "outputs": [
50
+ {
51
+ "output_type": "stream",
52
+ "name": "stderr",
53
+ "text": [
54
+ "100%|██████████| 9.91M/9.91M [00:00<00:00, 20.3MB/s]\n",
55
+ "100%|██████████| 28.9k/28.9k [00:00<00:00, 508kB/s]\n",
56
+ "100%|██████████| 1.65M/1.65M [00:00<00:00, 4.58MB/s]\n",
57
+ "100%|██████████| 4.54k/4.54k [00:00<00:00, 9.38MB/s]\n"
58
+ ]
59
+ }
60
+ ],
61
+ "source": [
62
+ "train_transform = transforms.Compose([\n",
63
+ " transforms.RandomRotation(10),\n",
64
+ " transforms.RandomAffine(\n",
65
+ " degrees=0,\n",
66
+ " translate=(0.1, 0.1),\n",
67
+ " scale=(0.9, 1.1),\n",
68
+ " shear=5\n",
69
+ " ),\n",
70
+ " transforms.ToTensor(),\n",
71
+ " transforms.Normalize((0.1307,), (0.3081,)),\n",
72
+ "])\n",
73
+ "\n",
74
+ "test_transform = transforms.Compose([\n",
75
+ " transforms.ToTensor(),\n",
76
+ " transforms.Normalize((0.1307,), (0.3081,)),\n",
77
+ "])\n",
78
+ "\n",
79
+ "train_dataset = datasets.MNIST(root=\"./data\", train=True, download=True, transform=train_transform)\n",
80
+ "test_dataset = datasets.MNIST(root=\"./data\", train=False, download=True, transform=test_transform)\n",
81
+ "\n",
82
+ "train_loader = DataLoader(train_dataset, batch_size=CONFIG[\"batch_size\"], shuffle=True, num_workers=CONFIG[\"num_workers\"])\n",
83
+ "test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=CONFIG[\"num_workers\"])"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 4,
89
+ "metadata": {
90
+ "id": "3SHyrmaMHCIJ"
91
+ },
92
+ "outputs": [],
93
+ "source": [
94
+ "class Model(nn.Module):\n",
95
+ " def __init__(self):\n",
96
+ " super().__init__()\n",
97
+ "\n",
98
+ " self.conv_layers = nn.Sequential(\n",
99
+ " # Block 1: 1 -> 32 channels, 28x28 -> 14x14\n",
100
+ " nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
101
+ " nn.BatchNorm2d(32),\n",
102
+ " nn.ReLU(),\n",
103
+ " nn.MaxPool2d(2),\n",
104
+ " nn.Dropout2d(0.25),\n",
105
+ "\n",
106
+ " # Block 2: 32 -> 64 channels, 14x14 -> 7x7\n",
107
+ " nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
108
+ " nn.BatchNorm2d(64),\n",
109
+ " nn.ReLU(),\n",
110
+ " nn.MaxPool2d(2),\n",
111
+ " nn.Dropout2d(0.25),\n",
112
+ "\n",
113
+ " # Block 3: 64 -> 128 channels, 7x7 -> 3x3\n",
114
+ " nn.Conv2d(64, 128, kernel_size=3, padding=1),\n",
115
+ " nn.BatchNorm2d(128),\n",
116
+ " nn.ReLU(),\n",
117
+ " nn.MaxPool2d(2),\n",
118
+ " nn.Dropout2d(0.25),\n",
119
+ "\n",
120
+ " # Block 4: 128 -> 256 channels, 3x3 -> 1x1\n",
121
+ " nn.Conv2d(128, 256, kernel_size=1),\n",
122
+ " nn.BatchNorm2d(256),\n",
123
+ " nn.ReLU(),\n",
124
+ " nn.MaxPool2d(2),\n",
125
+ " nn.Dropout2d(0.25),\n",
126
+ " )\n",
127
+ "\n",
128
+ " self.fc_layers = nn.Sequential(\n",
129
+ " nn.Flatten(), # 256 * 1 * 1 = 256\n",
130
+ " nn.Linear(256 * 1 * 1, 128),\n",
131
+ " nn.ReLU(),\n",
132
+ " nn.Dropout(0.25),\n",
133
+ " nn.Linear(128, 10)\n",
134
+ " )\n",
135
+ "\n",
136
+ " def forward(self, x):\n",
137
+ " x = self.conv_layers(x)\n",
138
+ " x = self.fc_layers(x)\n",
139
+ " return x"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 5,
145
+ "metadata": {
146
+ "id": "rEp8D2U8Ke6d",
147
+ "colab": {
148
+ "base_uri": "https://localhost:8080/"
149
+ },
150
+ "outputId": "92157bae-28fb-4aa0-cb5a-75dc935adbe9"
151
+ },
152
+ "outputs": [
153
+ {
154
+ "output_type": "stream",
155
+ "name": "stdout",
156
+ "text": [
157
+ "Model parameters: 160,842\n"
158
+ ]
159
+ }
160
+ ],
161
+ "source": [
162
+ "model = Model().to(CONFIG[\"device\"])\n",
163
+ "total_params = sum(p.numel() for p in model.parameters())\n",
164
+ "print(f\"Model parameters: {total_params:,}\")"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": 6,
170
+ "metadata": {
171
+ "id": "lL1TZN8MJoun"
172
+ },
173
+ "outputs": [],
174
+ "source": [
175
+ "optimizer = optim.AdamW(\n",
176
+ " model.parameters(),\n",
177
+ " lr=CONFIG[\"lr\"],\n",
178
+ " weight_decay=CONFIG[\"weight_decay\"],\n",
179
+ ")\n",
180
+ "\n",
181
+ "# Warmup for 5 epochs, then cosine decay\n",
182
+ "scheduler = optim.lr_scheduler.OneCycleLR(\n",
183
+ " optimizer,\n",
184
+ " max_lr=CONFIG[\"lr\"],\n",
185
+ " steps_per_epoch=len(train_loader),\n",
186
+ " epochs=CONFIG[\"epochs\"],\n",
187
+ " pct_start=0.1, # 10% warmup\n",
188
+ " anneal_strategy=\"cos\",\n",
189
+ ")\n",
190
+ "\n",
191
+ "criterion = nn.CrossEntropyLoss(label_smoothing=CONFIG[\"label_smoothing\"])"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": 7,
197
+ "metadata": {
198
+ "id": "k_RjpCkLLXGj"
199
+ },
200
+ "outputs": [],
201
+ "source": [
202
+ "def train_epoch(model, loader, optimizer, scheduler, criterion, device):\n",
203
+ " model.train()\n",
204
+ " total_loss, correct, total = 0.0, 0, 0\n",
205
+ "\n",
206
+ " for images, labels in loader:\n",
207
+ " images, labels = images.to(device), labels.to(device)\n",
208
+ "\n",
209
+ " optimizer.zero_grad()\n",
210
+ " outputs = model(images)\n",
211
+ " loss = criterion(outputs, labels)\n",
212
+ " loss.backward()\n",
213
+ " nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
214
+ " optimizer.step()\n",
215
+ " scheduler.step()\n",
216
+ "\n",
217
+ " total_loss += loss.item() * images.size(0)\n",
218
+ " correct += (outputs.argmax(1) == labels).sum().item()\n",
219
+ " total += images.size(0)\n",
220
+ "\n",
221
+ " return total_loss / total, correct / total\n",
222
+ "\n",
223
+ "\n",
224
+ "def evaluate(model, loader, device, tta=False):\n",
225
+ " \"\"\"Evaluate top-1 accuracy on loader. NOTE: the tta flag and tta_transforms are currently unused, so no Test-Time Augmentation is applied.\"\"\"\n",
226
+ " model.eval()\n",
227
+ " correct, total = 0, 0\n",
228
+ "\n",
229
+ " tta_transforms = [\n",
230
+ " transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),\n",
231
+ " transforms.Compose([transforms.RandomRotation(5), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),\n",
232
+ " transforms.Compose([transforms.RandomAffine(0, translate=(0.05, 0.05)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),\n",
233
+ " ]\n",
234
+ "\n",
235
+ " with torch.no_grad():\n",
236
+ " for images, labels in loader:\n",
237
+ " images, labels = images.to(device), labels.to(device)\n",
238
+ " outputs = model(images)\n",
239
+ " correct += (outputs.argmax(1) == labels).sum().item()\n",
240
+ " total += images.size(0)\n",
241
+ "\n",
242
+ " return correct / total"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": 8,
248
+ "metadata": {
249
+ "id": "aC2r9yTUO6l9",
250
+ "colab": {
251
+ "base_uri": "https://localhost:8080/"
252
+ },
253
+ "outputId": "7c804914-86c9-4863-aa04-f758d7b879c3"
254
+ },
255
+ "outputs": [
256
+ {
257
+ "output_type": "stream",
258
+ "name": "stdout",
259
+ "text": [
260
+ "\n",
261
+ "============================================================\n",
262
+ " Epoch Train Loss Train Acc Test Acc LR\n",
263
+ "============================================================\n",
264
+ " 1 1.5776 52.87% 95.11% 0.000395 ✓ BEST\n",
265
+ " 2 0.8644 87.52% 97.11% 0.001115 ✓ BEST\n",
266
+ " 3 0.7545 92.05% 98.13% 0.002006 ✓ BEST\n",
267
+ " 4 0.7104 93.52% 98.53% 0.002725 ✓ BEST\n",
268
+ " 5 0.6858 94.29% 98.75% 0.003000 ✓ BEST\n",
269
+ " 6 0.6660 95.03% 98.78% 0.002996 ✓ BEST\n",
270
+ " 7 0.6530 95.43% 98.84% 0.002985 ✓ BEST\n",
271
+ " 8 0.6437 95.54% 98.91% 0.002967 ✓ BEST\n",
272
+ " 9 0.6410 95.56% 99.14% 0.002942 ✓ BEST\n",
273
+ " 10 0.6323 95.84% 99.07% 0.002910\n",
274
+ " 11 0.6307 95.84% 99.10% 0.002870\n",
275
+ " 12 0.6261 95.96% 98.97% 0.002824\n",
276
+ " 13 0.6232 96.08% 99.05% 0.002772\n",
277
+ " 14 0.6203 96.12% 99.06% 0.002713\n",
278
+ " 15 0.6145 96.34% 99.02% 0.002649\n",
279
+ " 16 0.6124 96.50% 99.22% 0.002579 ✓ BEST\n",
280
+ " 17 0.6103 96.47% 99.09% 0.002504\n",
281
+ " 18 0.6075 96.63% 99.12% 0.002423\n",
282
+ " 19 0.6043 96.70% 99.09% 0.002339\n",
283
+ " 20 0.6038 96.70% 99.17% 0.002250\n",
284
+ " 21 0.6021 96.78% 99.27% 0.002157 ✓ BEST\n",
285
+ " 22 0.6010 96.78% 99.20% 0.002062\n",
286
+ " 23 0.5994 96.89% 99.24% 0.001963\n",
287
+ " 24 0.5955 97.02% 99.35% 0.001863 ✓ BEST\n",
288
+ " 25 0.5961 96.92% 99.21% 0.001760\n",
289
+ " 26 0.5919 97.17% 99.31% 0.001657\n",
290
+ " 27 0.5901 97.25% 99.30% 0.001552\n",
291
+ " 28 0.5897 97.18% 99.27% 0.001448\n",
292
+ " 29 0.5879 97.22% 99.27% 0.001343\n",
293
+ " 30 0.5891 97.17% 99.26% 0.001239\n",
294
+ " 31 0.5841 97.32% 99.25% 0.001137\n",
295
+ " 32 0.5839 97.36% 99.29% 0.001036\n",
296
+ " 33 0.5829 97.32% 99.37% 0.000938 ✓ BEST\n",
297
+ " 34 0.5800 97.46% 99.36% 0.000842\n",
298
+ " 35 0.5815 97.42% 99.38% 0.000750 ✓ BEST\n",
299
+ " 36 0.5778 97.52% 99.40% 0.000661 ✓ BEST\n",
300
+ " 37 0.5776 97.51% 99.37% 0.000576\n",
301
+ " 38 0.5770 97.60% 99.41% 0.000496 ✓ BEST\n",
302
+ " 39 0.5765 97.57% 99.38% 0.000421\n",
303
+ " 40 0.5758 97.57% 99.43% 0.000351 ✓ BEST\n",
304
+ " 41 0.5741 97.67% 99.41% 0.000286\n",
305
+ " 42 0.5741 97.61% 99.38% 0.000228\n",
306
+ " 43 0.5728 97.69% 99.40% 0.000176\n",
307
+ " 44 0.5731 97.71% 99.39% 0.000130\n",
308
+ " 45 0.5710 97.75% 99.38% 0.000090\n",
309
+ " 46 0.5700 97.79% 99.40% 0.000058\n",
310
+ " 47 0.5718 97.70% 99.38% 0.000033\n",
311
+ " 48 0.5712 97.77% 99.38% 0.000015\n",
312
+ " 49 0.5699 97.77% 99.38% 0.000004\n",
313
+ " 50 0.5717 97.70% 99.39% 0.000000\n",
314
+ "============================================================\n",
315
+ "\n",
316
+ "Best test accuracy: 99.43%\n"
317
+ ]
318
+ }
319
+ ],
320
+ "source": [
321
+ "best_acc = 0.0\n",
322
+ "history = {\"train_loss\": [], \"train_acc\": [], \"test_acc\": []}\n",
323
+ "\n",
324
+ "print(\"\\n\" + \"=\"*60)\n",
325
+ "print(f\"{'Epoch':>6} {'Train Loss':>10} {'Train Acc':>10} {'Test Acc':>10} {'LR':>10}\")\n",
326
+ "print(\"=\"*60)\n",
327
+ "\n",
328
+ "for epoch in range(1, CONFIG[\"epochs\"] + 1):\n",
329
+ " train_loss, train_acc = train_epoch(\n",
330
+ " model, train_loader, optimizer, scheduler, criterion, CONFIG[\"device\"]\n",
331
+ " )\n",
332
+ " test_acc = evaluate(model, test_loader, CONFIG[\"device\"])\n",
333
+ "\n",
334
+ " history[\"train_loss\"].append(train_loss)\n",
335
+ " history[\"train_acc\"].append(train_acc)\n",
336
+ " history[\"test_acc\"].append(test_acc)\n",
337
+ "\n",
338
+ " current_lr = scheduler.get_last_lr()[0]\n",
339
+ "\n",
340
+ " if test_acc > best_acc:\n",
341
+ " best_acc = test_acc\n",
342
+ " torch.save(model.state_dict(), \"mnist_best.pth\")\n",
343
+ " marker = \" ✓ BEST\"\n",
344
+ " else:\n",
345
+ " marker = \"\"\n",
346
+ "\n",
347
+ " print(f\"{epoch:>6} {train_loss:>10.4f} {train_acc*100:>9.2f}% {test_acc*100:>9.2f}% {current_lr:>10.6f}{marker}\")\n",
348
+ "\n",
349
+ "print(\"=\"*60)\n",
350
+ "print(f\"\\nBest test accuracy: {best_acc*100:.2f}%\")"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": 9,
356
+ "metadata": {
357
+ "id": "5QwlbG2YQ8Q4",
358
+ "colab": {
359
+ "base_uri": "https://localhost:8080/"
360
+ },
361
+ "outputId": "ae219648-1b25-4062-97f2-6dec61a96b17"
362
+ },
363
+ "outputs": [
364
+ {
365
+ "output_type": "stream",
366
+ "name": "stdout",
367
+ "text": [
368
+ "\n",
369
+ "Loading best model for final evaluation...\n",
370
+ "Final test accuracy: 99.43%\n"
371
+ ]
372
+ }
373
+ ],
374
+ "source": [
375
+ "print(\"\\nLoading best model for final evaluation...\")\n",
376
+ "model.load_state_dict(torch.load(\"mnist_best.pth\", map_location=CONFIG[\"device\"]))\n",
377
+ "final_acc = evaluate(model, test_loader, CONFIG[\"device\"])\n",
378
+ "print(f\"Final test accuracy: {final_acc*100:.2f}%\")"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "source": [
384
+ "def confusion_matrix(model, loader, device, num_classes=10):\n",
385
+ " model.eval()\n",
386
+ " matrix = np.zeros((num_classes, num_classes), dtype=int)\n",
387
+ " with torch.no_grad():\n",
388
+ " for images, labels in loader:\n",
389
+ " images = images.to(device)\n",
390
+ " preds = model(images).argmax(1).cpu().numpy()\n",
391
+ " for true, pred in zip(labels.numpy(), preds):\n",
392
+ " matrix[true][pred] += 1\n",
393
+ " return matrix\n",
394
+ "\n",
395
+ "cm = confusion_matrix(model, test_loader, CONFIG[\"device\"])\n",
396
+ "print(\"\\nConfusion Matrix (rows=true, cols=predicted):\")\n",
397
+ "print(\" \" + \" \".join(f\"{i:4}\" for i in range(10)))\n",
398
+ "for i, row in enumerate(cm):\n",
399
+ " errors = sum(row) - row[i]\n",
400
+ " print(f\"{i}: \" + \" \".join(f\"{v:4}\" for v in row) + f\" [{errors} errors]\")\n",
401
+ "\n",
402
+ "per_class_acc = cm.diagonal() / cm.sum(axis=1)\n",
403
+ "print(\"\\nPer-class accuracy:\")\n",
404
+ "for i, acc in enumerate(per_class_acc):\n",
405
+ " print(f\" Digit {i}: {acc*100:.1f}%\")"
406
+ ],
407
+ "metadata": {
408
+ "id": "tv567D7c8tT8",
409
+ "colab": {
410
+ "base_uri": "https://localhost:8080/"
411
+ },
412
+ "outputId": "9e585bd1-8862-428a-aa26-8886e087541b"
413
+ },
414
+ "execution_count": 10,
415
+ "outputs": [
416
+ {
417
+ "output_type": "stream",
418
+ "name": "stdout",
419
+ "text": [
420
+ "\n",
421
+ "Confusion Matrix (rows=true, cols=predicted):\n",
422
+ " 0 1 2 3 4 5 6 7 8 9\n",
423
+ "0: 980 0 0 0 0 0 0 0 0 0 [0 errors]\n",
424
+ "1: 0 1132 0 1 0 1 0 1 0 0 [3 errors]\n",
425
+ "2: 1 0 1025 2 0 0 1 3 0 0 [7 errors]\n",
426
+ "3: 0 0 0 1008 0 1 0 0 1 0 [2 errors]\n",
427
+ "4: 0 0 0 0 976 0 2 0 0 4 [6 errors]\n",
428
+ "5: 1 0 0 3 0 885 2 1 0 0 [7 errors]\n",
429
+ "6: 2 1 0 0 2 3 949 0 1 0 [9 errors]\n",
430
+ "7: 0 4 2 0 0 1 0 1020 0 1 [8 errors]\n",
431
+ "8: 0 0 2 1 0 1 0 0 968 2 [6 errors]\n",
432
+ "9: 0 0 0 0 4 1 0 3 1 1000 [9 errors]\n",
433
+ "\n",
434
+ "Per-class accuracy:\n",
435
+ " Digit 0: 100.0%\n",
436
+ " Digit 1: 99.7%\n",
437
+ " Digit 2: 99.3%\n",
438
+ " Digit 3: 99.8%\n",
439
+ " Digit 4: 99.4%\n",
440
+ " Digit 5: 99.2%\n",
441
+ " Digit 6: 99.1%\n",
442
+ " Digit 7: 99.2%\n",
443
+ " Digit 8: 99.4%\n",
444
+ " Digit 9: 99.1%\n"
445
+ ]
446
+ }
447
+ ]
448
+ }
449
+ ],
450
+ "metadata": {
451
+ "colab": {
452
+ "provenance": [],
453
+ "gpuType": "T4"
454
+ },
455
+ "kernelspec": {
456
+ "display_name": "Python 3",
457
+ "name": "python3"
458
+ },
459
+ "language_info": {
460
+ "name": "python"
461
+ },
462
+ "accelerator": "GPU"
463
+ },
464
+ "nbformat": 4,
465
+ "nbformat_minor": 0
466
+ }