| |
def ResNet50(num_classes, channels=3):
    """Construct a ``ResNet`` from ``Bottleneck`` blocks with a [3, 4, 6, 3] layout.

    Args:
        num_classes: Passed to ``ResNet`` as its third positional argument.
        channels: Passed to ``ResNet`` as its fourth positional argument
            (defaults to 3).

    Returns:
        The object produced by the ``ResNet(...)`` call.
    """
    # Presumably the number of blocks per stage -- confirm against ResNet.
    layer_spec = [3, 4, 6, 3]
    return ResNet(Bottleneck, layer_spec, num_classes, channels)
| |
# Build the 1000-class ResNet-50 and move it to the target device in one step.
model = ResNet50(num_classes=1000).to(device)

# Adam updates every model parameter; cross-entropy is the training objective.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()
|
|
| |
def evaluate_model(model, val_loader, criterion, class_names=None,
                   target_device=None, show_progress=True):
    """Run one evaluation pass and report loss, accuracy and per-class accuracy.

    Args:
        model: Module called as ``model(inputs)``; it is switched to eval mode.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        class_names: Sequence of class names indexed by label. Defaults to the
            module-level ``val_dataset.classes``.
        target_device: Device each batch is moved to. Defaults to the
            module-level ``device``.
        show_progress: When true, wrap the loader in a ``tqdm`` progress bar.

    Returns:
        Tuple ``(val_loss, accuracy, per_class_accuracy)``: the mean per-batch
        loss, the overall accuracy in percent, and a dict mapping class name
        to accuracy in percent for every class that had at least one sample.
        An empty loader yields ``(0.0, 0.0, {})`` instead of raising
        ``ZeroDivisionError``.
    """
    if class_names is None:
        class_names = val_dataset.classes
    if target_device is None:
        target_device = device
    num_classes = len(class_names)

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    # Counted while iterating so loaders without __len__ also work.
    num_batches = 0
    # Per-class tallies are kept on the CPU as integer tensors.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    batches = tqdm(val_loader, desc="Validating") if show_progress else val_loader
    with torch.no_grad():
        for inputs, labels in batches:
            inputs, labels = inputs.to(target_device), labels.to(target_device)

            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()
            num_batches += 1

            _, predicted = torch.max(outputs, 1)
            hits = predicted == labels
            correct += int(hits.sum().item())
            total += labels.size(0)

            # bincount tallies the whole batch at once, replacing the
            # per-sample Python loop and its one .item() call per sample.
            class_total += torch.bincount(labels, minlength=num_classes).cpu()
            class_correct += torch.bincount(
                labels[hits], minlength=num_classes
            ).cpu()

    if num_batches:
        val_loss /= num_batches
    accuracy = 100.0 * correct / total if total else 0.0
    per_class_accuracy = {
        name: 100.0 * hit / seen
        for name, hit, seen in zip(
            class_names, class_correct.tolist(), class_total.tolist()
        )
        if seen > 0
    }
    return val_loss, accuracy, per_class_accuracy
|
|
| |
# Train for num_epochs, reporting the mean batch loss and accuracy per epoch.
print('Training the model on ImageNet')
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass.
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Track training accuracy from the highest-scoring class per sample.
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    train_loss = running_loss / len(train_loader)
    train_accuracy = 100.0 * correct / total

    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%")

# NOTE(review): the source lost its indentation; validation and saving are
# placed after the epoch loop because the message says "after training".
# Confirm this nesting against the original script.
print("Validating the model on unseen data after training...")
val_loss, val_accuracy, per_class_accuracy = evaluate_model(model, val_loader, criterion)
print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
print("Per-class Accuracy:")
for class_name, acc in per_class_accuracy.items():
    print(f"{class_name}: {acc:.2f}%")

# One path variable feeds both the save and the message, so the reported file
# name cannot drift from the file actually written.
checkpoint_path = "resnet50_imagenet.pth"
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved as {checkpoint_path}")