| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| import torch |
| import torch.nn as nn |
| from torch.optim import Adam |
| from tqdm import tqdm |
| import torch.nn.functional as F |
|
|
| |
| from tensorboardX import SummaryWriter |
| import numpy as np |
|
|
| |
| |
| from model import * |
| import lovasz_losses as L |
|
|
| |
| |
| import sys |
| import os |
|
|
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
# Default path used to persist the final checkpoint (see end of main()).
model_dir = './model/s3_net_model.pth'
# Expected number of CLI arguments: MDL_PATH, TRAIN_PATH, DEV_PATH.
NUM_ARGS = 3
NUM_EPOCHS = 20000
BATCH_SIZE = 1024
# Keyword names for the Adam optimizer's constructor (used as dict keys
# when building opt_params in main()).
LEARNING_RATE = "lr"
BETAS = "betas"
EPS = "eps"
WEIGHT_DECAY = "weight_decay"


# Network input/output widths: 3 input channels (scan, intensity, angle
# of incidence), 10 semantic output classes.
NUM_INPUT_CHANNELS = 3
NUM_OUTPUT_CHANNELS = 10
# Weight of the KL-divergence term in the total loss.
BETA = 0.01


# set_seed / SEED1 come from `model` (star import) — seeds RNGs for
# reproducibility; presumably torch/numpy/random. TODO confirm in model.py.
set_seed(SEED1)
|
|
| |
| |
def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate of every param group from a fixed schedule.

    Schedule: 1e-4 up to epoch 50000, then 2e-5; beyond epoch 480000 the
    rate is further decayed by a factor of 0.1 for every 110000 epochs
    elapsed since the start of training.
    """
    if epoch > 480000:
        # Base rate at this point is already 2e-5; apply the step decay.
        rate = 2e-5 * (0.1 ** (epoch // 110000))
    elif epoch > 50000:
        rate = 2e-5
    else:
        rate = 1e-4

    for group in optimizer.param_groups:
        group['lr'] = rate
|
|
|
|
| |
def train(model, dataloader, dataset, device, optimizer, ce_criterion, lovasz_criterion, class_weights, epoch, epochs):
    """Run one training epoch and return per-batch-averaged losses.

    Each batch dict is expected to hold 'scan', 'intensity',
    'angle_incidence' and 'label' tensors (schema set by VaeTestDataset —
    confirm there). The model returns (semantic_scan, semantic_channels,
    kl_loss); the total loss is CE + BETA * KL + weighted lovasz.

    Returns:
        (train_loss, train_kl_loss, train_ce_loss): running losses
        divided by the number of batches processed.
    """
    model.train()

    running_loss = 0.0
    kl_avg_loss = 0.0
    ce_avg_loss = 0.0

    counter = 0

    num_batches = int(len(dataset) / dataloader.batch_size)
    for i, batch in tqdm(enumerate(dataloader), total=num_batches):
        counter += 1

        scans = batch['scan'].to(device)
        intensities = batch['intensity'].to(device)
        angle_incidence = batch['angle_incidence'].to(device)
        labels = batch['label'].to(device)

        batch_size = scans.size(0)

        optimizer.zero_grad()

        semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)

        # CE uses reduction='sum', so normalize by the batch size here.
        ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)
        lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
        # FIX: move the class weights to the configured device instead of a
        # hard-coded "cuda" string, so CPU-only runs also work.
        lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()

        loss = ce_loss + BETA * kl_loss + lovasz_loss

        # Under DataParallel `loss` may be a per-GPU vector; backward with a
        # ones tensor sums the per-replica gradients.
        loss.backward(torch.ones_like(loss))
        optimizer.step()

        # Reduce per-GPU vectors to scalars before logging.
        if torch.cuda.device_count() > 1:
            loss = loss.mean()
            ce_loss = ce_loss.mean()
            # FIX: reduce the KL term itself (the original overwrote kl_loss
            # with the mean of the lovasz loss).
            kl_loss = kl_loss.mean()
            lovasz_loss = lovasz_loss.mean()

        running_loss += loss.item()
        # FIX: accumulate the KL divergence, not the lovasz loss, so the
        # value returned (and logged as "training kl loss") is really KL.
        kl_avg_loss += kl_loss.item()
        ce_avg_loss += ce_loss.item()

        if i % 512 == 0:
            print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, CE_Loss: {:.4f}, Lovasz_Loss: {:.4f}'
                  .format(epoch, epochs, i + 1, num_batches, loss.item(), ce_loss.item(), lovasz_loss.item()))

    train_loss = running_loss / counter
    train_kl_loss = kl_avg_loss / counter
    train_ce_loss = ce_avg_loss / counter

    return train_loss, train_kl_loss, train_ce_loss
|
|
| |
def validate(model, dataloader, dataset, device, ce_criterion, lovasz_criterion, class_weights):
    """Evaluate the model for one pass over the dev set (no gradients).

    Mirrors train(): same batch schema and same loss composition
    (CE + BETA * KL + weighted lovasz), but under torch.no_grad() and
    without optimizer updates.

    Returns:
        (val_loss, val_kl_loss, val_ce_loss): running losses divided by
        the number of batches processed.
    """
    model.eval()

    running_loss = 0.0
    kl_avg_loss = 0.0
    ce_avg_loss = 0.0

    counter = 0

    num_batches = int(len(dataset) / dataloader.batch_size)
    with torch.no_grad():
        for i, batch in tqdm(enumerate(dataloader), total=num_batches):
            counter += 1

            scans = batch['scan'].to(device)
            intensities = batch['intensity'].to(device)
            angle_incidence = batch['angle_incidence'].to(device)
            labels = batch['label'].to(device)

            batch_size = scans.size(0)

            semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)

            # CE uses reduction='sum', so normalize by the batch size here.
            ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)
            lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
            # FIX: use the configured device rather than a hard-coded
            # "cuda" string, so CPU-only runs also work.
            lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()

            loss = ce_loss + BETA * kl_loss + lovasz_loss

            # Reduce per-GPU vectors to scalars before logging.
            if torch.cuda.device_count() > 1:
                loss = loss.mean()
                ce_loss = ce_loss.mean()
                # FIX: reduce the KL term itself (the original overwrote
                # kl_loss with the mean of the lovasz loss).
                kl_loss = kl_loss.mean()

            running_loss += loss.item()
            # FIX: accumulate the KL divergence, not the lovasz loss, so
            # the value logged as "validation kl loss" is really KL.
            kl_avg_loss += kl_loss.item()
            ce_avg_loss += ce_loss.item()

    val_loss = running_loss / counter
    val_kl_loss = kl_avg_loss / counter
    val_ce_loss = ce_avg_loss / counter

    return val_loss, val_kl_loss, val_ce_loss
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
def main(argv):
    """Entry point: train S3Net on the given train/dev splits.

    Args:
        argv: CLI arguments (without the program name):
            [MDL_PATH, TRAIN_PATH, DEV_PATH].

    Loads an existing checkpoint from MDL_PATH when present, trains for
    up to NUM_EPOCHS epochs, logs to tensorboardX, checkpoints every
    2000 epochs, and saves the final state to MDL_PATH.
    """
    if len(argv) != NUM_ARGS:
        # FIX: the usage string listed 5 arguments while NUM_ARGS is 3 and
        # only three are consumed below.
        print("usage: python train.py [MDL_PATH] [TRAIN_PATH] [DEV_PATH]")
        exit(-1)

    mdl_path = argv[0]
    pTrain = argv[1]
    pDev = argv[2]

    # Ensure the checkpoint directory exists.
    # FIX: guard against an empty dirname (mdl_path with no directory
    # component), which would make os.makedirs('') raise.
    odir = os.path.dirname(mdl_path)
    if odir and not os.path.exists(odir):
        os.makedirs(odir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('...Start reading data...')

    train_dataset = VaeTestDataset(pTrain, 'train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4,
                                                   shuffle=True, drop_last=True, pin_memory=True)

    dev_dataset = VaeTestDataset(pDev, 'dev')
    dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, num_workers=2,
                                                 shuffle=True, drop_last=True, pin_memory=True)

    # Per-class weights (inverse-frequency style; values are precomputed
    # for this dataset — TODO confirm their provenance).
    class_weights = np.array([2.514399, 1.4917144, 0.51608694, 0.659483, 1.0900991, 1.6461798, 0.32852992, 1.5633508, 0.9236576, 0.10251398])

    class_weights = torch.Tensor(class_weights)
    print("class weights: ", class_weights)
    # FIX: Tensor.to() is not in-place — the original discarded the result,
    # leaving the weights on CPU.
    class_weights = class_weights.to(device)
    print('...Finish reading data...')

    model = S3Net(input_channels=NUM_INPUT_CHANNELS,
                  output_channels=NUM_OUTPUT_CHANNELS)
    model.to(device)

    opt_params = {LEARNING_RATE: 0.001,
                  BETAS: (.9, 0.999),
                  EPS: 1e-08,
                  WEIGHT_DECAY: .001}

    # CE with reduction='sum': train()/validate() normalize by batch size.
    ce_criterion = nn.CrossEntropyLoss(reduction='sum', weight=class_weights)
    ce_criterion.to(device)
    lovasz_criterion = L.LovaszSoftmax(reduction='sum', ignore_index=0)
    lovasz_criterion.to(device)

    optimizer = Adam(model.parameters(), **opt_params)

    epochs = NUM_EPOCHS

    # Resume from an existing checkpoint when available.
    if os.path.exists(mdl_path):
        checkpoint = torch.load(mdl_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Load epoch {} success'.format(start_epoch))
    else:
        start_epoch = 0
        print('No trained models, restart training')

    if torch.cuda.device_count() > 1:
        print("Let's use 2 of total", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        model.to(device)

    writer = SummaryWriter('runs')

    epoch_num = 0
    for epoch in range(start_epoch + 1, epochs):
        adjust_learning_rate(optimizer, epoch)

        train_epoch_loss, train_kl_epoch_loss, train_ce_epoch_loss = train(
            model, train_dataloader, train_dataset, device, optimizer, ce_criterion, lovasz_criterion, class_weights, epoch, epochs
        )
        valid_epoch_loss, valid_kl_epoch_loss, valid_ce_epoch_loss = validate(
            model, dev_dataloader, dev_dataset, device, ce_criterion, lovasz_criterion, class_weights
        )

        writer.add_scalar('training loss',
                          train_epoch_loss,
                          epoch)
        writer.add_scalar('training kl loss',
                          train_kl_epoch_loss,
                          epoch)
        writer.add_scalar('training ce loss',
                          train_ce_epoch_loss,
                          epoch)

        writer.add_scalar('validation loss',
                          valid_epoch_loss,
                          epoch)
        writer.add_scalar('validation kl loss',
                          valid_kl_epoch_loss,
                          epoch)
        writer.add_scalar('validation ce loss',
                          valid_ce_epoch_loss,
                          epoch)

        print('Train set: Average loss: {:.4f}'.format(train_epoch_loss))
        print('Validation set: Average loss: {:.4f}'.format(valid_epoch_loss))

        # Periodic checkpoint; unwrap the DataParallel module so the state
        # dict keys stay loadable on a single GPU.
        if epoch % 2000 == 0:
            if torch.cuda.device_count() > 1:
                state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            else:
                state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            path = './model/model' + str(epoch) + '.pth'
            torch.save(state, path)

        epoch_num = epoch

    # Final save to the user-supplied model path.
    if torch.cuda.device_count() > 1:
        state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    else:
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    torch.save(state, mdl_path)

    return True
| |
| |
|
|
|
|
| |
| |
| if __name__ == '__main__': |
| main(sys.argv[1:]) |
| |
| |
|
|