| |
| |
| |
| |
| import torch |
|
|
def compute_batch_accuracy(pred, label):
    """Count how many predictions in a batch match their labels.

    Args:
        pred: Tensor of predicted class indices.
        label: Tensor of ground-truth class indices with the same shape as
            ``pred``; dim 0 is treated as the batch dimension.

    Returns:
        ``(correct, total)`` where ``correct`` is a 0-dim tensor holding the
        number of elementwise matches and ``total`` is ``label.size(0)``.

    Raises:
        ValueError: if ``pred`` and ``label`` have different shapes.
    """
    # Guard against silent broadcasting: e.g. pred of shape (N,) compared with
    # label of shape (N, 1) would broadcast to (N, N) and inflate the count.
    if pred.shape != label.shape:
        raise ValueError(
            "pred and label must have the same shape, got "
            f"{tuple(pred.shape)} and {tuple(label.shape)}"
        )
    correct = (pred == label).sum()
    return correct, label.size(0)
|
|
def compute_set_accuracy(model, test_loader, device=None):
    """Return the top-1 accuracy of *model* over every batch in *test_loader*.

    If *model* is a ``torch.nn.Module`` it is switched to evaluation mode for
    the pass (so dropout / batch-norm layers use inference behaviour), and its
    previous train/eval mode is restored afterwards, even on error.

    Args:
        model: Callable mapping an input batch to per-class scores. The
            prediction is ``argmax`` over dim 1, so outputs are presumably
            shaped ``(batch, num_classes)`` -- TODO confirm for all callers.
        test_loader: Iterable yielding ``(inputs, labels)`` pairs of tensors.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter; if that cannot be determined, CUDA when
            available, otherwise CPU.

    Returns:
        A 0-dim tensor equal to ``correct / total``.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples.
    """
    is_module = isinstance(model, torch.nn.Module)

    if device is None:
        first_param = next(model.parameters(), None) if is_module else None
        if first_param is not None:
            # Follow the model so inputs never land on a different device.
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    was_training = is_module and model.training
    if is_module:
        model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)

                correct_batch, total_batch = compute_batch_accuracy(
                    torch.argmax(outputs, dim=1), labels
                )
                correct += correct_batch
                total += total_batch
    finally:
        # Leave the caller's model in the mode we found it in.
        if was_training:
            model.train()

    if total == 0:
        raise ZeroDivisionError("test_loader yielded no samples; accuracy is undefined")
    return correct / total