|
|
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from tqdm import tqdm |
|
|
| from torchvision import transforms |
| import torchvision |
|
|
| from torch.utils.data import DataLoader |
|
|
# Per-channel normalization statistics: the standard ImageNet mean/std used by
# torchvision's pretrained classification models (a pretrained resnet18 is
# loaded further down in this script).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image (train and val alike): resize to a
# fixed 512x512, convert to a float tensor, then normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
|
|
def make_dir(path):
    """Create directory *path*, including any missing parents, if it is absent.

    Safe to call repeatedly. Using ``exist_ok=True`` avoids the race in an
    ``exists()`` check followed by ``makedirs()``, where another process could
    create the directory in between and make ``makedirs`` raise.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: if *path* already exists but is not a directory.
    """
    import os

    os.makedirs(path, exist_ok=True)
# Output directory for the per-epoch checkpoints written by torch.save below.
make_dir('models')

batch_size = 8

# ImageFolder assigns class indices from the sub-directory names under root.
train_set = torchvision.datasets.ImageFolder(root='data/cat_vs_dog/train', transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                          num_workers=0)

# Validation is not shuffled. Evaluation does not need random order, and a
# fixed batch composition keeps the summed validation loss (which ends up in
# the checkpoint file names) reproducible across runs.
val_dataset = torchvision.datasets.ImageFolder(root='data/cat_vs_dog/val', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=0)
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ImageNet-pretrained ResNet-18 whose final fully connected layer is replaced
# by a fresh 2-output linear head (two classes; the data path suggests
# cat vs dog).
# `weights=True` is only accepted through torchvision's deprecated legacy
# shim, which emits a warning. ResNet18_Weights.DEFAULT names the same
# IMAGENET1K_V1 weights explicitly.
net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 2)

criterion = nn.CrossEntropyLoss()
net = net.to(device)
optimizer = torch.optim.Adam(lr=0.0001, params=net.parameters())
eposhs = 100  # number of training epochs (name kept as-is; the training loop reads it)
|
|
for epoch in range(eposhs):
    print(f'--------------------{epoch}--------------------')

    # ---- Training phase ----------------------------------------------------
    # Switch back to training mode every epoch. The validation phase below
    # calls net.eval(); without this, from the second epoch on BatchNorm would
    # use (and never update) its running statistics while training.
    net.train()
    correct_train = 0
    sum_loss_train = 0  # sum of per-batch mean losses, not a per-sample mean
    total_correct_train = 0  # number of training samples seen this epoch
    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        output = net(inputs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        sum_loss_train += loss.item()
        total_correct_train += labels.size(0)
        predicted = output.argmax(dim=1)
        correct_train += (predicted == labels).sum().item()

    acc_train = correct_train / total_correct_train
    # Prints "training accuracy is X%".
    print('训练准确率是{:.3f}%:'.format(acc_train*100) )

    # ---- Validation phase --------------------------------------------------
    net.eval()
    correct_val = 0
    sum_loss_val = 0  # sum of per-batch mean losses, not a per-sample mean
    total_correct_val = 0  # number of validation samples seen this epoch
    # No gradients are needed for evaluation; skipping autograd saves memory
    # and time. A single forward pass per batch serves both loss and accuracy.
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            output = net(inputs)
            loss = criterion(output, labels)
            sum_loss_val += loss.item()
            total_correct_val += labels.size(0)
            predicted = output.argmax(dim=1)
            correct_val += (predicted == labels).sum().item()

    acc_val = correct_val / total_correct_val
    # Prints "validation accuracy is X%".
    print('验证准确率是{:.3f}%:'.format(acc_val*100) )

    # Save the whole module (not just its state_dict) once per epoch. The file
    # name encodes: epoch, summed train loss, train accuracy %, summed val
    # loss, val accuracy %.
    torch.save(net,'models/{}-{:.5f}_{:.3f}%_{:.5f}_{:.3f}%.pth'.format(epoch,sum_loss_train,acc_train *100,sum_loss_val,acc_val*100))
|
|
|
|