import os
import re
from copy import deepcopy

import yaml
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils.oom_decorator import OOMReturn, safe_backward
from utils.logger import print_log


class TrainConfig:
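    """Lightweight container for training hyperparameters; any extra keyword
    arguments (e.g. the `optimizer`/`scheduler` sub-configs) are attached as
    attributes via `__dict__.update`."""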

    def __init__(self, save_dir, max_epoch, warmup=0,
                 metric_min_better=True, patience=3,
                 grad_clip=None, save_topk=-1,
                 grad_interval=1,
                 val_freq=1,
                 **kwargs):
        self.save_dir = save_dir
        self.max_epoch = max_epoch
        self.warmup = warmup
        self.metric_min_better = metric_min_better
        self.patience = patience if patience > 0 else max_epoch
        self.grad_clip = grad_clip
        self.save_topk = save_topk
        self.grad_interval = grad_interval
        self.val_freq = val_freq
        self.__dict__.update(kwargs)

    def add_parameter(self, **kwargs):
        self.__dict__.update(kwargs)

    def __str__(self):
        return str(self.__class__) + ': ' + str(self.__dict__)


class Trainer:
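    """Generic training loop with OOM/NaN-tolerant steps, top-k checkpointing,
    early stopping, optional DDP, and TensorBoard logging. Subclasses can
    override `train_step`/`valid_step` and the epoch hooks."""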

    def __init__(self, model, train_loader, valid_loader, config: dict, save_config: dict):
        self.model = model
        self.config = TrainConfig(**config)
        self.save_config = save_config
        self.optimizer = self.get_optimizer()
        sched_config = self.get_scheduler(self.optimizer)
        if sched_config is None:
            sched_config = {
                'scheduler': None,
                'frequency': None
            }
        self.scheduler = sched_config['scheduler']
        self.sched_freq = sched_config['frequency']
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        # distributed training; -1 means single-process (no DDP)
        self.local_rank = -1

        # logging
        self.version = self._get_version()
        self.config.save_dir = os.path.join(self.config.save_dir, f'version_{self.version}')
        self.model_dir = os.path.join(self.config.save_dir, 'checkpoint')
        self.writer = None
        self.writer_buffer = {}

        # training state
        self.global_step = 0
        self.valid_global_step = 0
        self.epoch = 0
        self.last_valid_metric = None
        self.topk_ckpt_map = []
        self.patience = self.config.patience

    @classmethod
    def to_device(cls, data, device):
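        """Recursively move tensors inside nested dicts/lists/tuples to `device`."""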
        if isinstance(data, dict):
            for key in data:
                data[key] = cls.to_device(data[key], device)
        elif isinstance(data, (list, tuple)):
            res = [cls.to_device(item, device) for item in data]
            data = type(data)(res)
        elif hasattr(data, 'to'):
            data = data.to(device)
        return data

    def _is_main_proc(self):
        # rank 0 under DDP, or -1 when training without DDP
        return self.local_rank == 0 or self.local_rank == -1

    def _get_version(self):
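        """Scan save_dir for existing `version_*` folders and return the next index."""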
        version, pattern = -1, r'version_(\d+)'
        if os.path.exists(self.config.save_dir):
            for fname in os.listdir(self.config.save_dir):
                ver = re.findall(pattern, fname)
                if len(ver):
                    version = max(int(ver[0]), version)
        return version + 1

    def is_oom_return(self, value):
        return isinstance(value, OOMReturn)

    def _train_epoch(self, device):
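        """Run one training epoch: forward, OOM/NaN-tolerant backward, optional
        gradient clipping, optimizer step, and per-batch scheduler step."""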
        if self.train_loader.sampler is not None and self.local_rank != -1:
            self.train_loader.sampler.set_epoch(self.epoch)
        t_iter = tqdm(self.train_loader) if self._is_main_proc() else self.train_loader
        for batch in t_iter:
            batch = self.to_device(batch, device)
            loss = self.train_step(batch, self.global_step)
            if self.is_oom_return(loss):
                print_log(f'Out of memory, local rank {self.local_rank}', level='WARN')
                loss = loss.fake_loss
            elif torch.isnan(loss):
                print_log(f'Loss is nan, local rank {self.local_rank}', level='WARN')
                # zero-valued substitute that still touches every parameter,
                # so gradient synchronization under DDP stays consistent
                loss = sum([p.norm() for p in self.model.parameters() if p.dtype == torch.float]) * 0.0
            self.optimizer.zero_grad()
            backward_ok = safe_backward(loss, self.model)
            if not backward_ok:
                print_log('Backward out of memory, skip', level='WARN')
                loss = loss.detach()
            if self.config.grad_clip is not None:
                ori_grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
                self.log('Grad Norm', ori_grad_norm.cpu(), self.global_step)
            self.optimizer.step()
            if hasattr(t_iter, 'set_postfix'):
                t_iter.set_postfix(loss=loss.item(), version=self.version)
            self.global_step += 1
            if self.sched_freq == 'batch':
                self.scheduler.step()
        if self.sched_freq == 'epoch':
            self.scheduler.step()
        self._train_epoch_end(device)

    def _train_epoch_end(self, device):
        # hook for subclasses
        return

    def _aggregate_val_metric(self, metric_arr):
        return np.mean(metric_arr)

    def _valid_epoch_begin(self, device):
        # hook for subclasses
        return

    def _valid_epoch(self, device):
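        """Evaluate on the validation set, save a checkpoint on the main process,
        maintain the top-k list, update early-stopping patience, and flush
        buffered validation logs."""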
        metric_arr = []
        self.model.eval()
        self._valid_epoch_begin(device)
        with torch.no_grad():
            t_iter = tqdm(self.valid_loader) if self._is_main_proc() else self.valid_loader
            for batch in t_iter:
                batch = self.to_device(batch, device)
                metric = self.valid_step(batch, self.valid_global_step)
                metric_arr.append(metric.cpu().item())
                self.valid_global_step += 1

        # aggregate the metric and save a checkpoint on the main process
        valid_metric = self._aggregate_val_metric(metric_arr)
        if self._is_main_proc():
            save_path = os.path.join(self.model_dir, f'epoch{self.epoch}_step{self.global_step}.ckpt')
            module_to_save = self.model.module if self.local_rank == 0 else self.model
            torch.save(module_to_save, save_path)
            self._maintain_topk_checkpoint(valid_metric, save_path)
            print_log(f'Validation: {valid_metric}, save path: {save_path}')
        # patience is updated on every process so that early stopping stays in sync
        if self._metric_better(valid_metric):
            self.patience = self.config.patience
        else:
            self.patience -= 1
        if self.sched_freq == 'val_epoch':
            self.scheduler.step(valid_metric)
        self.last_valid_metric = valid_metric

        # flush buffered validation scalars as epoch-level means
        for name in self.writer_buffer:
            value = np.mean(self.writer_buffer[name])
            if self._is_main_proc():
                print_log(f'{name}: {value}')
            self.log(name, value, self.epoch)
        self.writer_buffer = {}
        self._valid_epoch_end(device)
        self.model.train()

    def _valid_epoch_end(self, device):
        # hook for subclasses
        return

    def _metric_better(self, new):
        old = self.last_valid_metric
        if old is None:
            return True
        if self.config.metric_min_better:
            return new < old
        else:
            return old < new

    def _maintain_topk_checkpoint(self, valid_metric, ckpt_path):
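        """Insert the new checkpoint into the sorted top-k list, evict the worst
        checkpoints when over capacity (`save_topk <= 0` keeps everything), and
        persist the current mapping to `topk_map.txt`."""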
        topk = self.config.save_topk
        if self.config.metric_min_better:
            better = lambda a, b: a < b
        else:
            better = lambda a, b: a > b
        insert_pos = len(self.topk_ckpt_map)
        for i, (metric, _) in enumerate(self.topk_ckpt_map):
            if better(valid_metric, metric):
                insert_pos = i
                break
        self.topk_ckpt_map.insert(insert_pos, (valid_metric, ckpt_path))

        # drop the worst checkpoints beyond capacity
        if topk > 0:
            while len(self.topk_ckpt_map) > topk:
                last_ckpt_path = self.topk_ckpt_map[-1][1]
                os.remove(last_ckpt_path)
                self.topk_ckpt_map.pop()

        # record the current top-k mapping
        topk_map_path = os.path.join(self.model_dir, 'topk_map.txt')
        with open(topk_map_path, 'w') as fout:
            for metric, path in self.topk_ckpt_map:
                fout.write(f'{metric}: {path}\n')

    def _modify_writer(self):
        # hook for subclasses to customize the SummaryWriter
        return

    def train(self, device_ids, local_rank):
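        """Entry point: set up logging and checkpoint directories on the main
        process, move the model to its device (wrapping with DDP when
        `local_rank != -1`), then alternate training and validation epochs
        until `max_epoch` is reached or patience runs out."""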
        # initialize distributed training state
        self.local_rank = local_rank

        if self._is_main_proc():
            self.writer = SummaryWriter(self.config.save_dir)
            self._modify_writer()
            if not os.path.exists(self.model_dir):
                os.makedirs(self.model_dir)
            with open(os.path.join(self.config.save_dir, 'train_config.yaml'), 'w') as fout:
                yaml.safe_dump(self.save_config, fout)

        main_device_id = local_rank if local_rank != -1 else device_ids[0]
        device = torch.device('cpu' if main_device_id == -1 else f'cuda:{main_device_id}')
        self.model.to(device)
        if local_rank != -1:
            print_log(f'Using data parallel, local rank {local_rank}, all {device_ids}')
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=True
            )
        else:
            print_log(f'training on {device_ids}')
        for _ in range(self.config.max_epoch):
            if self._is_main_proc():
                print_log(f'epoch{self.epoch} starts')
            self._train_epoch(device)
            if (self.epoch + 1) % self.config.val_freq == 0:
                if self._is_main_proc():
                    print_log('validating ...')
                self._valid_epoch(device)
            self.epoch += 1
            if self.patience <= 0:
                break

    def log(self, name, value, step, val=False, batch_size=1):
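        """Log a scalar. With `val=True`, buffer `batch_size` copies of the value
        and flush the mean at the end of the validation epoch; otherwise write
        directly to TensorBoard (main process only)."""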
        if self._is_main_proc():
            if isinstance(value, torch.Tensor):
                value = value.cpu().item()
            if val:
                if name not in self.writer_buffer:
                    self.writer_buffer[name] = []
                self.writer_buffer[name].extend([value] * batch_size)
            else:
                self.writer.add_scalar(name, value, step)

    def get_optimizer(self):
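        """Instantiate the optimizer named in `config['optimizer']`. Assumed
        config shape, for illustration: `class` is any optimizer name in
        `torch.optim` and the remaining keys are passed through as keyword
        arguments, e.g. {'class': 'Adam', 'lr': 1e-4}."""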
        opt_cfg = deepcopy(self.config.optimizer)
        cls = getattr(torch.optim, opt_cfg.pop('class'))
        # only optimize parameters that require gradients
        optimizer = cls(filter(lambda p: p.requires_grad, self.model.parameters()), **opt_cfg)
        return optimizer

    def get_scheduler(self, optimizer):
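        """Instantiate the LR scheduler described by `config['scheduler']`, or
        return None when absent. `frequency` selects when `step` is called:
        'batch', 'epoch', or 'val_epoch' (the last passes the validation metric,
        e.g. for ReduceLROnPlateau). Assumed config shape, for illustration:
        {'class': 'ReduceLROnPlateau', 'factor': 0.8, 'patience': 5,
        'frequency': 'val_epoch'}."""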
        if not hasattr(self.config, 'scheduler'):
            return None
        sched_cfg = deepcopy(self.config.scheduler)
        cls = getattr(torch.optim.lr_scheduler, sched_cfg.pop('class'))
        freq = sched_cfg.pop('frequency')
        return {
            'scheduler': cls(optimizer, **sched_cfg),
            'frequency': freq
        }

    def train_step(self, batch, batch_idx):
        # default implementation: the model returns its training loss directly
        loss = self.model(batch)
        self.log('Loss/train', loss, batch_idx)
        return loss

    def valid_step(self, batch, batch_idx):
        # default implementation: reuse the loss as the validation metric
        loss = self.model(batch)
        self.log('Loss/validation', loss, batch_idx, val=True)
        return loss
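
# A minimal usage sketch (hypothetical `model` and dataloaders; the real entry
# script presumably lives elsewhere in the repository):
#
#     config = {
#         'save_dir': './ckpts',
#         'max_epoch': 20,
#         'patience': 3,
#         'optimizer': {'class': 'Adam', 'lr': 1e-4},
#         'scheduler': {'class': 'ReduceLROnPlateau', 'factor': 0.8,
#                       'patience': 5, 'frequency': 'val_epoch'},
#     }
#     trainer = Trainer(model, train_loader, valid_loader, config, save_config=config)
#     trainer.train(device_ids=[0], local_rank=-1)  # local_rank=-1: no DDP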