| """ |
This file trains a simple LeNet-style convolutional network on the MNIST dataset and exports it to ONNX.
| """ |
|
|
| import random |
|
|
| import torch |
| from torchvision import datasets, transforms |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
|
|
def get_mnist_dataset(root='./data', download=True):
    """Return the MNIST ``(train_set, test_set)`` pair.

    Images are converted to tensors with ``transforms.ToTensor()``; no
    further normalization is applied.

    Args:
        root: Directory holding the MNIST files (default ``'./data'``, the
            location previously hard-coded here).
        download: Forwarded to ``datasets.MNIST``; when true torchvision may
            download the data into ``root`` if needed (default ``True``, the
            previous behavior).

    Returns:
        A ``(train_set, test_set)`` tuple of ``datasets.MNIST`` objects.
    """
    transform = transforms.ToTensor()
    train_set = datasets.MNIST(root=root, train=True, transform=transform, download=download)
    test_set = datasets.MNIST(root=root, train=False, transform=transform, download=download)
    return train_set, test_set
|
|
| |
|
|
class Classifier(torch.nn.Module):
    """LeNet-style convolutional classifier producing 10 class scores.

    Two stages of 5x5 convolution -> ReLU -> 2x2 max-pool shrink a
    1-channel 28x28 input to 32 feature maps of 4x4 (28 -> 24 -> 12 -> 8 -> 4).
    A three-layer fully connected head then maps those 512 features to 10
    outputs. All layers live in a single ``nn.Sequential`` stored as
    ``self.network``.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (layers are created in this order,
        # which fixes how parameters are initialized from the RNG).
        feature_layers = [
            nn.Conv2d(1, 32, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 32, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
        ]
        # Classifier head: 32 channels x 4 x 4 positions after the second pool.
        head_layers = [
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, 100), nn.ReLU(),
            nn.Linear(100, 100), nn.ReLU(),
            nn.Linear(100, 10),
        ]
        self.network = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Run ``x`` through ``self.network`` and return the raw scores."""
        scores = self.network(x)
        return scores
|
|
| |
|
|
def compute_accuracy(model, data_set, nb_samples):
    """Estimate classification accuracy on randomly drawn samples.

    ``nb_samples`` indices are drawn uniformly *with replacement* from
    ``data_set`` using ``torch.randint``, so the result is a Monte-Carlo
    estimate rather than the exact accuracy over the whole set. Each image is
    reshaped to a single ``(1, 1, 28, 28)`` batch. The predicted label is the
    argmax of the model output.

    Args:
        model: ``nn.Module`` mapping a ``(1, 1, 28, 28)`` tensor to class
            scores. It is evaluated in eval mode, and its previous
            training/eval mode is restored on return.
        data_set: Indexable sequence of ``(image, label)`` pairs. Each image
            must contain 28 * 28 elements.
        nb_samples: Number of random samples to evaluate; must be positive.

    Returns:
        Fraction of sampled items whose predicted label equals the true label.

    Raises:
        ValueError: If ``nb_samples`` is not positive.
    """
    if nb_samples <= 0:
        raise ValueError(f"nb_samples must be positive, got {nb_samples}")

    was_training = model.training
    model.eval()
    nb_valid = 0
    try:
        # Inference only: don't build an autograd graph for every sample.
        with torch.no_grad():
            for _ in range(nb_samples):
                sample_idx = torch.randint(len(data_set), size=(1,)).item()
                img, label = data_set[sample_idx]
                x = torch.reshape(img, (1, 1, 28, 28))
                pred_label = torch.argmax(model(x)).item()
                if label == pred_label:
                    nb_valid += 1
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)
    return nb_valid / nb_samples
|
|
| |
|
|
def train_model(NB_ITERATION, CHECK_PERIOD, train_set, test_set, classifier, lr=0.01):
    """Train ``classifier`` in place with single-sample SGD on a one-hot MSE loss.

    Each iteration does the following:

    * draws one random ``(image, label)`` pair from ``train_set``;
    * reshapes the image to a ``(1, 1, 28, 28)`` batch;
    * regresses the model output onto the one-hot encoding of ``label``
      with mean-squared error;
    * takes one plain SGD step.

    Every ``CHECK_PERIOD`` iterations, including iteration 0, test accuracy
    is estimated with ``compute_accuracy`` on ``CHECK_PERIOD`` random test
    samples, printed, and recorded.

    Args:
        NB_ITERATION: Number of training steps.
        CHECK_PERIOD: Evaluation interval (in steps) and number of test
            samples per evaluation; must be positive.
        train_set: Indexable sequence of ``(image, label)`` pairs; images must
            contain 28 * 28 elements and labels must be in ``[0, 10)``.
        test_set: Same format as ``train_set``; used only for evaluation.
        classifier: ``nn.Module`` producing 10 scores; its parameters are
            updated in place.
        lr: SGD learning rate (default 0.01, the value previously hard-coded).

    Returns:
        List of the accuracy estimates recorded at each check, in order.

    Raises:
        ValueError: If ``CHECK_PERIOD`` is not positive.
    """
    if CHECK_PERIOD <= 0:
        raise ValueError(f"CHECK_PERIOD must be positive, got {CHECK_PERIOD}")

    # Plain SGD (no momentum / weight decay) applies p -= lr * p.grad, like the
    # former hand-written loop, but skips parameters that received no gradient.
    optimizer = torch.optim.SGD(classifier.parameters(), lr=lr)
    accuracy_history = []
    for it in range(NB_ITERATION):
        # Use torch's RNG (as compute_accuracy does) so torch.manual_seed
        # controls the whole run.
        sample_idx = torch.randint(len(train_set), size=(1,)).item()
        img, label = train_set[sample_idx]
        x = torch.reshape(img, (1, 1, 28, 28))
        y = torch.zeros(1, 10)
        y[0, label] = 1.0

        optimizer.zero_grad()
        loss = F.mse_loss(classifier(x), y)  # (input, target) order
        loss.backward()
        optimizer.step()

        if it % CHECK_PERIOD == 0:
            accuracy = compute_accuracy(classifier, test_set, CHECK_PERIOD)
            accuracy_history.append(accuracy)
            print(f'it {it}: accuracy = {accuracy:.8f} ')
    return accuracy_history
|
|
|
|
def create_lenet(onnx_path='lenet.onnx', nb_iteration=50000, check_period=3000):
    """Train a LeNet-style ``Classifier`` on MNIST and export it to ONNX.

    Loads the data with ``get_mnist_dataset``, trains a fresh ``Classifier``
    with ``train_model``, and traces ``classifier.network`` with a
    ``(1, 1, 28, 28)`` dummy input. The exported graph names its input
    ``"input"`` and its output ``"output"``.

    Args:
        onnx_path: Destination file for the ONNX model (default
            ``'lenet.onnx'``, the previously hard-coded path).
        nb_iteration: Number of training steps (default 50000).
        check_period: Evaluation interval passed to ``train_model``
            (default 3000).

    Returns:
        The trained ``Classifier`` (left in eval mode).
    """
    train_set, test_set = get_mnist_dataset()
    classifier = Classifier()

    NB_ITERATION = nb_iteration
    CHECK_PERIOD = check_period
    print("NB_ITERATIONS = ", NB_ITERATION)
    print("CHECK_PERIOD = ", CHECK_PERIOD)
    print("\nTraining LeNet...")
    train_model(NB_ITERATION, CHECK_PERIOD, train_set, test_set, classifier)

    # Export in inference mode. The dummy input only fixes shape/dtype for
    # tracing; zeros avoid the uninitialized memory torch.Tensor(1, 1, 28, 28)
    # would hand to the tracer.
    classifier.eval()
    dummy_input = torch.zeros(1, 1, 28, 28)
    torch.onnx.export(
        classifier.network,
        dummy_input,
        onnx_path,
        verbose=False,
        input_names=["input"],
        output_names=["output"],
    )
    return classifier