| |
| |
| |
| |
| |
| |
|
|
| import io |
| import os |
| import re |
| from typing import List |
| from timm.utils.model_ema import ModelEma |
| import torch |
| import torch.distributed as dist |
| from timm.utils import get_state_dict |
| import subprocess |
|
|
|
|
|
|
|
|
def load_checkpoint(config, model, optimizer, logger, model_ema=None):
    """Resume training state from the checkpoint at ``config.MODEL.RESUME``.

    Model weights are always loaded (non-strictly). EMA weights are loaded when
    ``config.MODEL_EMA`` is set and an EMA wrapper is given. When not in eval
    mode and the checkpoint contains both ``'optimizer'`` and ``'epoch'``, the
    optimizer state is restored and ``config.TRAIN.START_EPOCH`` is set to
    ``epoch + 1``.

    Args:
        config: Config node. Reads ``MODEL.RESUME``, ``MODEL_EMA`` and
            ``EVAL_MODE``. Writes ``TRAIN.START_EPOCH`` between
            ``defrost()`` / ``freeze()``.
        model: Module that receives ``checkpoint['model']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer']``.
        logger: Logger used for progress messages.
        model_ema: Optional EMA wrapper exposing an ``.ema`` module.

    Returns:
        The ``max_accuracy`` stored in the checkpoint when the training state
        is restored, otherwise ``0.0``.
    """
    logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
    # NOTE(review): torch >= 2.6 defaults ``torch.load`` to ``weights_only=True``,
    # which rejects checkpoints that pickle arbitrary objects (e.g. a saved config).
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    logger.info("Already loaded checkpoint to memory..")
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0

    if config.MODEL_EMA:
        if model_ema is None:
            logger.warning("config.MODEL_EMA is set but no model_ema was passed; skipping EMA restore")
        else:
            # Checkpoints written without EMA store ``state_dict_ema: None``. Treat that
            # like a missing key rather than handing None to load_state_dict.
            ema_state = checkpoint.get('state_dict_ema')
            if ema_state is not None:
                model_ema.ema.load_state_dict(ema_state, strict=False)
                logger.info("Loaded state_dict_ema")
            else:
                model_ema.ema.load_state_dict(checkpoint['model'], strict=False)
                logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        # Only reachable when 'epoch' is present, so the lookup below cannot fail.
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    return max_accuracy
|
|
def load_checkpoint_finetune(config, model, logger, model_ema=None):
    """Initialise ``model`` (and optionally its EMA copy) from a pretrained checkpoint.

    Loads ``torch.load(config.MODEL.FINETUNE)['model']`` and drops every key
    starting with ``'cls'``, which this code treats as pre-training-only
    weights. The remaining weights are loaded non-strictly into ``model`` and,
    when given, into ``model_ema.ema``.

    Args:
        config: Config node. Reads ``MODEL.FINETUNE`` (checkpoint path).
        model: Module to initialise.
        logger: Logger used for progress messages.
        model_ema: Optional EMA wrapper exposing an ``.ema`` module.
    """
    logger.info(f"==============> Finetune {config.MODEL.FINETUNE}....................")
    # NOTE(review): torch >= 2.6 defaults ``torch.load`` to ``weights_only=True``;
    # checkpoints that pickle non-tensor objects will be rejected there.
    checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')['model']
    converted_weights = {}
    for key, value in checkpoint.items():
        if key.startswith('cls'):
            logger.info(f'key: {key} is used for pretrain, discarded.')
            continue
        converted_weights[key] = value

    msg = model.load_state_dict(converted_weights, strict=False)
    logger.info(msg)
    if model_ema is not None:
        ema_msg = model_ema.ema.load_state_dict(converted_weights, strict=False)
        logger.info("==============> Loaded Pretrained statedict into EMA....................")
        logger.info(ema_msg)
    # Drop the CPU copy and release cached CUDA blocks (a no-op when CUDA is unused).
    del checkpoint
    torch.cuda.empty_cache()
| |
|
|
def save_checkpoint(config, epoch, model, epoch_accuracy, max_accuracy, optimizer, logger, model_ema=None):
    """Write the training checkpoint for ``epoch`` into ``config.OUTPUT``.

    Always writes ``ckpt_epoch_{epoch}.pth``. When ``epoch_accuracy`` is
    strictly greater than ``max_accuracy``, the same state is also written to
    ``best.pth``.

    Args:
        config: Config node. ``config.OUTPUT`` is the target directory, and the
            node itself is stored under ``'config'``.
        epoch: Epoch index stored under ``'epoch'``.
        model: Module whose ``state_dict()`` is stored under ``'model'``.
        epoch_accuracy: Accuracy achieved this epoch.
        max_accuracy: Best accuracy before this epoch. The stored value is
            ``max(max_accuracy, epoch_accuracy)``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        logger: Logger used for progress messages.
        model_ema: Optional EMA wrapper. Its state (via timm's
            ``get_state_dict``) is stored under ``'state_dict_ema'``, or
            ``None`` when absent.
    """
    if model_ema is not None:
        logger.info("Model EMA is not None...")
        ema_state = get_state_dict(model_ema)
    else:
        ema_state = None

    # The builtin ``input`` function used to be pickled here by mistake; it is omitted.
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'max_accuracy': max(max_accuracy, epoch_accuracy),
        'epoch': epoch,
        'state_dict_ema': ema_state,
        'config': config,
    }

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    if epoch_accuracy > max_accuracy:
        best_path = os.path.join(config.OUTPUT, 'best.pth')
        torch.save(save_state, best_path)
        logger.info(f"{best_path} saved (new best accuracy {epoch_accuracy})")
    logger.info(f"{save_path} saved !!!")
|
|
|
|
def get_grad_norm(parameters, norm_type=2):
    """Return the total ``norm_type``-norm of the gradients of ``parameters``.

    Parameters whose ``.grad`` is ``None`` are skipped. The per-tensor norms
    are combined as ``(sum_i ||g_i||^p)^(1/p)``. For ``norm_type=inf`` the
    result is the largest absolute gradient entry, matching the definition
    used by ``torch.nn.utils.clip_grad_norm_``. Nothing is modified.

    Args:
        parameters: A tensor or an iterable of tensors.
        norm_type: Order of the norm; anything ``float()`` accepts, including
            ``'inf'``.

    Returns:
        The total norm as a Python float (``0.0`` when no gradients exist).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return 0.0
    if norm_type == float('inf'):
        # x ** inf followed by ** (1 / inf) would collapse to 0 or 1; use max |g| instead.
        return max(g.abs().max().item() for g in grads)
    total_norm = sum(g.norm(norm_type).item() ** norm_type for g in grads)
    return total_norm ** (1. / norm_type)
|
|
|
|
def auto_resume_helper(output_dir, logger):
    """Return the most recently modified ``ckpt_*.pth`` file in ``output_dir``.

    "Most recent" is decided by file modification time (``os.path.getmtime``),
    not by the epoch number in the name. Files such as ``best.pth`` are ignored.

    Args:
        output_dir: Directory to scan.
        logger: Logger used for progress messages.

    Returns:
        The full path of the newest checkpoint, or ``None`` when the directory
        does not exist or contains no matching file.
    """
    if not os.path.isdir(output_dir):
        # A fresh run may not have created its output dir yet: nothing to resume.
        logger.info(f"Output directory {output_dir} does not exist, nothing to resume")
        return None
    checkpoints = [name for name in os.listdir(output_dir)
                   if name.startswith('ckpt_') and name.endswith('.pth')]
    logger.info(f"All checkpoints found in {output_dir}: {checkpoints}")
    if not checkpoints:
        return None
    latest_checkpoint = max((os.path.join(output_dir, name) for name in checkpoints),
                            key=os.path.getmtime)
    logger.info(f"The latest checkpoint found: {latest_checkpoint}")
    return latest_checkpoint
|
|
|
|
def reduce_tensor(tensor):
    """Average ``tensor`` over every rank of the default process group.

    Works on a clone, so the caller's tensor is not modified. Requires an
    initialised ``torch.distributed`` process group.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(dist.get_world_size())
    return averaged
|
|
def denormalize(tensor: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Undo per-channel normalization: ``x -> x * std + mean``, then clamp to [0, 1].

    Intended as the inverse of :class:`torchvision.transforms.Normalize` for
    float tensor images; PIL images are not supported.

    Args:
        tensor: Floating-point image tensor with at least three dimensions,
            e.g. (C, H, W) or (B, C, H, W).
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided out.
        inplace: When True, ``tensor`` itself is modified and returned;
            otherwise (default) a clone is modified and the input is untouched.

    Returns:
        The denormalized tensor with values clamped to ``[0.0, 1.0]``.

    Raises:
        TypeError: If ``tensor`` is not a ``torch.Tensor`` or is not floating point.
        ValueError: If ``tensor`` has fewer than three dimensions, or ``std``
            contains a zero after conversion to the tensor's dtype.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
    if not tensor.is_floating_point():
        raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype))
    if tensor.ndim < 3:
        raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
                         '{}.'.format(tensor.size()))

    out = tensor if inplace else tensor.clone()
    dtype, device = out.dtype, out.device
    mean_t = torch.as_tensor(mean, dtype=dtype, device=device)
    std_t = torch.as_tensor(std, dtype=dtype, device=device)
    if torch.any(std_t == 0):
        raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))

    # A flat per-channel vector gets two trailing singleton dims so it broadcasts over (H, W).
    if mean_t.ndim == 1:
        mean_t = mean_t[:, None, None]
    if std_t.ndim == 1:
        std_t = std_t[:, None, None]

    out.mul_(std_t)
    out.add_(mean_t)
    out.clamp_(0.0, 1.0)
    return out
|
|
|
|