import wandb

import os
import shutil
import argparse
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from pepflow.utils.vc import get_version, has_changes
from pepflow.utils.misc import BlackHole, inf_iterator, load_config, seed_all, get_logger, get_new_log_dir, current_milli_time
from pepflow.utils.data import PaddingCollate
from pepflow.utils.train import ScalarMetricAccumulator, count_parameters, get_optimizer, get_scheduler, log_losses, recursive_to, sum_weighted_losses

from models_con.pep_dataloader import PepDataset
from models_con.flow_model import FlowModel

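# Usage sketch (illustrative; the script filename is an assumption):
#   python train.py --config ./configs/angle/learn_angle.yaml --logdir ./logs --tag myrun
#   python train.py --config ./configs/angle/learn_angle.yaml --resume <checkpoint.pt>
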
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./configs/angle/learn_angle.yaml')
    parser.add_argument('--logdir', type=str, default='./logs')
    parser.add_argument('--debug', action='store_true', default=False)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--name', type=str, default='pepflow')
    args = parser.parse_args()

    # Version control: record the current commit and warn if the working tree is dirty
    branch, version = get_version()
    version_short = '%s-%s' % (branch, version[:7])
    if has_changes() and not args.debug:
        c = input('There are uncommitted changes. Start training anyway? (y/n) ')
        if c != 'y':
            exit()

    # Load configs
    config, config_name = load_config(args.config)
    seed_all(config.train.seed)
    config['device'] = args.device

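    # Keys this script reads from the config (illustrative layout, not the
    # actual learn_angle.yaml contents):
    #   model: ...                                  # passed to FlowModel
    #   train: seed, batch_size, max_iters, val_freq, max_grad_norm,
    #          loss_weights, optimizer, scheduler: {type: plateau, ...}
    #   dataset:
    #     train: {structure_dir, dataset_dir, name, reset}
    #     val:   {structure_dir, dataset_dir, name, reset}   # assumed; see Data below
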
    # Logging: in debug mode nothing is persisted; otherwise metrics go to wandb
    if args.debug:
        logger = get_logger('train', None)
        writer = BlackHole()
    else:
        run = wandb.init(project=args.name, config=config, name='%s[%s]' % (config_name, args.tag))
        writer = BlackHole()  # no tensorboard writer; shared utils get a no-op sink
        if args.resume:
            log_dir = os.path.dirname(os.path.dirname(args.resume))
        else:
            log_dir = get_new_log_dir(args.logdir, prefix='%s[%s]' % (config_name, version_short), tag=args.tag)
        with open(os.path.join(log_dir, 'commit.txt'), 'w') as f:
            f.write(branch + '\n')
            f.write(version + '\n')
        ckpt_dir = os.path.join(log_dir, 'checkpoints')
        os.makedirs(ckpt_dir, exist_ok=True)
        logger = get_logger('train', log_dir)
        # Keep a copy of the config next to the logs for reproducibility
        if not os.path.exists(os.path.join(log_dir, os.path.basename(args.config))):
            shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config)))
    logger.info(args)
    logger.info(config)

    # Data
    logger.info('Loading datasets...')
    train_dataset = PepDataset(structure_dir=config.dataset.train.structure_dir, dataset_dir=config.dataset.train.dataset_dir,
                               name=config.dataset.train.name, transform=None, reset=config.dataset.train.reset)
    # NOTE: assumes the config provides a `dataset.val` block mirroring `dataset.train`;
    # validate() below needs a val_loader to iterate over.
    val_dataset = PepDataset(structure_dir=config.dataset.val.structure_dir, dataset_dir=config.dataset.val.dataset_dir,
                             name=config.dataset.val.name, transform=None, reset=config.dataset.val.reset)
    # PaddingCollate pads variable-length examples in a batch to a common length
    train_loader = DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=True, collate_fn=PaddingCollate(), num_workers=args.num_workers, pin_memory=True)
    train_iterator = inf_iterator(train_loader)
    val_loader = DataLoader(val_dataset, batch_size=config.train.batch_size, shuffle=False, collate_fn=PaddingCollate(), num_workers=args.num_workers)
    logger.info('Train %d | Val %d' % (len(train_dataset), len(val_dataset)))

    # Model
    logger.info('Building model...')
    model = FlowModel(config.model).to(args.device)
    logger.info('Number of parameters: %d' % count_parameters(model))

    # Optimizer & scheduler
    optimizer = get_optimizer(config.train.optimizer, model)
    scheduler = get_scheduler(config.train.scheduler, optimizer)
    optimizer.zero_grad()
    it_first = 1

    # Resume
    if args.resume is not None:
        logger.info('Resuming from checkpoint: %s' % args.resume)
        ckpt = torch.load(args.resume, map_location=args.device)
        it_first = ckpt['iteration']
        model.load_state_dict(ckpt['model'])
        logger.info('Resuming optimizer states...')
        optimizer.load_state_dict(ckpt['optimizer'])
        logger.info('Resuming scheduler states...')
        scheduler.load_state_dict(ckpt['scheduler'])

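    # One optimizer step per call: draw the next batch from the infinite train
    # iterator, run forward/backward, sanitize gradients, and log scalars.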
    def train(it):
        time_start = current_milli_time()
        model.train()

        # Prepare data
        batch = recursive_to(next(train_iterator), args.device)

        # Forward pass
        loss_dict = model(batch)
        loss = sum_weighted_losses(loss_dict, config.train.loss_weights)
        time_forward_end = current_milli_time()

        # On a NaN loss, dump the offending batch and all states for post-mortem
        # debugging, then skip the step by backpropagating a zero loss instead.
        if torch.isnan(loss):
            logger.warning('NaN loss at iteration %d!' % it)
            if not args.debug:  # log_dir only exists outside debug mode
                torch.save({'batch': batch, 'loss': loss, 'loss_dict': loss_dict,
                            'model': model.state_dict(), 'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(), 'iteration': it},
                           os.path.join(log_dir, 'nan.pt'))
            loss = torch.zeros_like(loss, requires_grad=True)

        # Backward pass
        loss.backward()

        # Zero out NaN gradients so one bad batch cannot corrupt the weights
        for param in model.parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                param.grad[torch.isnan(param.grad)] = 0

        orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()
        time_backward_end = current_milli_time()

        # Logging
        scalar_dict = {
            'grad': orig_grad_norm,
            'lr': optimizer.param_groups[0]['lr'],
            'time_forward': (time_forward_end - time_start) / 1000,
            'time_backward': (time_backward_end - time_forward_end) / 1000,
        }
        log_losses(loss, loss_dict, scalar_dict, it=it, tag='train', logger=logger)

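    # Full pass over the validation set; also steps the LR scheduler and returns
    # the average weighted loss so the main loop can record it with checkpoints.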
    def validate(it):
        scalar_accum = ScalarMetricAccumulator()
        with torch.no_grad():
            model.eval()
            for i, batch in enumerate(tqdm(val_loader, desc='Validate', dynamic_ncols=True)):
                # Prepare data
                batch = recursive_to(batch, args.device)

                # Forward pass
                loss_dict = model(batch)
                loss = sum_weighted_losses(loss_dict, config.train.loss_weights)
                scalar_accum.add(name='loss', value=loss, batchsize=len(batch['aa']), mode='mean')
                for k, v in loss_dict['scalar'].items():
                    scalar_accum.add(name=k, value=v, batchsize=len(batch['aa']), mode='mean')

        avg_loss = scalar_accum.get_average('loss')
        summary = scalar_accum.log(it, 'val', logger=logger, writer=writer)
        if not args.debug:  # wandb is only initialized outside debug mode
            for k, v in summary.items():
                wandb.log({f'val/{k}': v}, step=it)

        # Trigger the scheduler: plateau schedulers need the validation loss
        if config.train.scheduler.type == 'plateau':
            scheduler.step(avg_loss)
        else:
            scheduler.step()
        return avg_loss

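    # Main loop: one gradient step per iteration; every val_freq iterations run
    # validation and snapshot the full training state.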
    try:
        for it in range(it_first, config.train.max_iters + 1):
            train(it)
            if it % config.train.val_freq == 0:
                avg_val_loss = validate(it)
                if not args.debug:  # ckpt_dir only exists outside debug mode
                    ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
                    torch.save({
                        'config': config,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'iteration': it,
                        'avg_val_loss': avg_val_loss,
                    }, ckpt_path)
    except KeyboardInterrupt:
        logger.info('Terminating...')