import weakref

import torch
from torch import nn
from tqdm.auto import tqdm
|
|
|
# The local file `data.py` is imported as the module `data`; writing
# `data.py` here would ask for a submodule `py` inside a package `data`
# and fail with ModuleNotFoundError.
from data import create_dataloaders

# Both loaders are built once, at import time; train_step and test_step read
# them as module-level globals.
train_loader, test_loader = create_dataloaders()
|
|
|
|
|
def accuracy_fn(y_true, y_pred):
    """Return the percentage (0-100) of predictions that match the labels.

    The two tensors are compared element-wise. The number of equal positions
    is divided by ``len(y_pred)`` and then scaled by 100.
    """
    hits = y_true.eq(y_pred).sum().item()
    fraction = hits / len(y_pred)
    return fraction * 100
|
|
|
|
|
loss_fn = nn.CrossEntropyLoss()

# Learning rate of the Adam optimizer that train_step builds when the caller
# does not pass one.
DEFAULT_LR = 1e-3

# No model exists at import time, so an optimizer cannot be created at module
# level. Instead, one Adam optimizer is built per model on first use and
# cached here. Reusing it on later calls keeps Adam's running statistics
# across epochs. Weak keys mean the cache never keeps a model alive.
_default_optimizers = weakref.WeakKeyDictionary()


def _default_optimizer(model):
    """Return the cached Adam optimizer for *model*, creating it on first use."""
    optimizer = _default_optimizers.get(model)
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=DEFAULT_LR)
        _default_optimizers[model] = optimizer
    return optimizer


def train_step(model, optimizer=None):
    """Run one training pass of *model* over ``train_loader``.

    Args:
        model: Network to train. It is put in training mode and its
            parameters are updated by ``optimizer`` after every batch.
        optimizer: Optimizer to use for the updates. If omitted, an Adam
            optimizer (lr=DEFAULT_LR) is created for this model on the first
            call and reused on every later call for the same model.

    Returns:
        ``(mean_loss, mean_accuracy)``. Each is the average of the per-batch
        values over all batches; accuracy is a percentage.
    """
    if optimizer is None:
        optimizer = _default_optimizer(model)

    train_loss, train_accuracy = 0.0, 0.0
    model.train()

    for x, y in train_loader:
        y_logits = model(x)
        y_pred = y_logits.argmax(dim=1)

        loss = loss_fn(y_logits, y)
        train_loss += loss.item()
        train_accuracy += accuracy_fn(y, y_pred)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    num_batches = len(train_loader)
    train_loss /= num_batches
    train_accuracy /= num_batches
    return train_loss, train_accuracy
|
|
|
|
|
def test_step(model):
    """Evaluate *model* on ``test_loader`` with gradient tracking disabled.

    The model is switched to eval mode. The loss and accuracy of each batch
    are computed inside ``torch.inference_mode()`` and accumulated.

    Returns:
        ``(mean_loss, mean_accuracy)``. Each is the average of the per-batch
        values over all batches; accuracy is a percentage.
    """
    model.eval()
    total_loss = 0
    total_acc = 0

    with torch.inference_mode():
        for inputs, targets in test_loader:
            logits = model(inputs)
            total_loss += loss_fn(logits, targets).item()
            total_acc += accuracy_fn(targets, logits.argmax(dim=1))

    batch_count = len(test_loader)
    return total_loss / batch_count, total_acc / batch_count
|
|
|
def train(model, epochs):
    """Train *model* for ``epochs`` rounds, evaluating after each one.

    Every round calls ``train_step`` and then ``test_step``. A tqdm bar shows
    progress over the rounds.

    Args:
        model: Network passed to ``train_step`` and ``test_step``.
        epochs: Number of train-then-evaluate rounds to run.

    Returns:
        ``(model, metrics)``. ``metrics`` maps "train_loss", "test_loss",
        "train_acc" and "test_acc" to lists holding one value per epoch.
    """
    history = {"train_loss": [], "test_loss": [],
               "train_acc": [], "test_acc": []}

    for _ in tqdm(range(epochs)):
        epoch_train_loss, epoch_train_acc = train_step(model)
        epoch_test_loss, epoch_test_acc = test_step(model)

        history["train_loss"].append(epoch_train_loss)
        history["train_acc"].append(epoch_train_acc)
        history["test_loss"].append(epoch_test_loss)
        history["test_acc"].append(epoch_test_acc)

    return model, history
|
|
|