import logging
from contextlib import contextmanager
from typing import Optional

import torch

from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
                             is_npu_available)
from mmengine.logging import print_log
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION


@contextmanager
def autocast(device_type: Optional[str] = None,
             dtype: Optional[torch.dtype] = None,
             enabled: bool = True,
             cache_enabled: Optional[bool] = None):
| """A wrapper of ``torch.autocast`` and ``toch.cuda.amp.autocast``. |
| |
| Pytorch 1.5.0 provide ``torch.cuda.amp.autocast`` for running in |
| mixed precision , and update it to ``torch.autocast`` in 1.10.0. |
| Both interfaces have different arguments, and ``torch.autocast`` |
| support running with cpu additionally. |
| |
| This function provides a unified interface by wrapping |
| ``torch.autocast`` and ``torch.cuda.amp.autocast``, which resolves the |
| compatibility issues that ``torch.cuda.amp.autocast`` does not support |
| running mixed precision with cpu, and both contexts have different |
| arguments. We suggest users using this function in the code |
| to achieve maximized compatibility of different PyTorch versions. |
| |
| Note: |
| ``autocast`` requires pytorch version >= 1.5.0. If pytorch version |
| <= 1.10.0 and cuda is not available, it will raise an error with |
| ``enabled=True``, since ``torch.cuda.amp.autocast`` only support cuda |
| mode. |
| |
    Examples:
        >>> # case 1: 1.10.0 > PyTorch version >= 1.5.0
        >>> with autocast():
        >>>     # run in mixed precision context
        >>>     pass
        >>> with autocast(device_type='cpu'):
        >>>     # raise error, torch.cuda.amp.autocast only supports cuda mode
        >>>     pass
        >>> # case 2: PyTorch version >= 1.10.0
        >>> with autocast():
        >>>     # default cuda mixed precision context
        >>>     pass
        >>> with autocast(device_type='cpu'):
        >>>     # cpu mixed precision context
        >>>     pass
        >>> with autocast(
        >>>         device_type='cuda', enabled=True, cache_enabled=True):
        >>>     # enable mixed precision context with more specific arguments
        >>>     pass
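        >>> # A minimal training-step sketch; ``model``, ``inputs``,
        >>> # ``targets`` and ``loss_fn`` are illustrative placeholders,
        >>> # not part of this module:
        >>> with autocast():
        >>>     outputs = model(inputs)
        >>>     loss = loss_fn(outputs, targets)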

    Args:
        device_type (str, optional): Device type to run autocast on, e.g.
            'cuda' or 'cpu'. If None, it is inferred from the current
            device. Defaults to None.
        dtype (torch.dtype, optional): Whether to use ``torch.float16`` or
            ``torch.bfloat16``. If None, the default autocast dtype of the
            device is used. Defaults to None.
        enabled (bool): Whether autocasting should be enabled in the region.
            Defaults to True.
        cache_enabled (bool, optional): Whether the weight cache inside
            autocast should be enabled. If None, the current global setting
            is kept. Defaults to None.
    """
    assert digit_version(TORCH_VERSION) >= digit_version('1.5.0'), (
        'The minimum pytorch version requirement of mmengine is 1.5.0, but '
        f'got {TORCH_VERSION}')

    if (digit_version('1.5.0') <= digit_version(TORCH_VERSION) <
            digit_version('1.10.0')):
        # `torch.cuda.amp.autocast` only accepts the `enabled` argument, so
        # `dtype` and `cache_enabled` cannot be forwarded on this version.
        assert (
            device_type == 'cuda' or device_type == 'mlu'
            or device_type is None), (
                'Pytorch version under 1.10.0 only supports running '
                'automatic mixed precision with cuda or mlu')
        if dtype is not None or cache_enabled is not None:
            print_log(
                '`dtype` and `cache_enabled` will not work for '
                '`autocast` since your Pytorch version: '
                f'{TORCH_VERSION} < 1.10.0',
                logger='current',
                level=logging.WARNING)
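        # Enter the device-specific legacy autocast context. Only the
        # `enabled` flag can be passed through on PyTorch < 1.10.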
        if is_npu_available():
            with torch.npu.amp.autocast(enabled=enabled):
                yield
        elif is_mlu_available():
            with torch.mlu.amp.autocast(enabled=enabled):
                yield
        elif is_cuda_available():
            with torch.cuda.amp.autocast(enabled=enabled):
                yield
        else:
            if not enabled:
                yield
            else:
                raise RuntimeError(
                    'If pytorch version is between 1.5.0 and 1.10.0, '
                    '`autocast` is only available in cuda, mlu or npu mode')
    else:
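        # On PyTorch >= 1.10, delegate to `torch.autocast`, which also
        # accepts `device_type`, `dtype` and `cache_enabled`.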
        if cache_enabled is None:
            cache_enabled = torch.is_autocast_cache_enabled()
        device = get_device()
        device_type = device if device_type is None else device_type
        if device_type == 'cuda':
            if dtype is None:
                # Default to the configured autocast gpu dtype
                # (float16 unless changed).
                dtype = torch.get_autocast_gpu_dtype()
            if dtype == torch.bfloat16 and not \
                    torch.cuda.is_bf16_supported():
                raise RuntimeError(
                    'Current CUDA Device does not support bfloat16. Please '
                    'switch dtype to float16.')
        elif device_type == 'cpu':
            if dtype is None:
                # The default dtype of cpu autocast is bfloat16.
                dtype = torch.bfloat16
            assert dtype == torch.bfloat16, (
                'In CPU autocast, only the `torch.bfloat16` dtype is '
                'supported')
        elif device_type == 'mlu':
            # Fall through to the `torch.autocast` call below.
            pass
        elif device_type == 'npu':
            # Fall through to the `torch.autocast` call below.
            pass
        elif device_type == 'musa':
            if dtype is None:
                dtype = torch.get_autocast_gpu_dtype()
            with torch.musa.amp.autocast(
                    enabled=enabled, dtype=dtype,
                    cache_enabled=cache_enabled):
                yield
            return
        else:
            # If an unsupported device type is given while autocast is
            # enabled, raise; with `enabled=False` just run in the default
            # precision.
            if not enabled:
                yield
                return
            else:
                raise ValueError(
                    'User specified autocast device_type must be cuda, cpu, '
                    f'mlu, npu or musa, but got {device_type}')
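        # Finally, enter the built-in `torch.autocast` context with the
        # resolved arguments.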
        with torch.autocast(
                device_type=device_type,
                enabled=enabled,
                dtype=dtype,
                cache_enabled=cache_enabled):
            yield
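

# A minimal, illustrative smoke test of the wrapper above. It is not part of
# the original module and assumes PyTorch >= 1.10 so that cpu autocast is
# available.
if __name__ == '__main__':
    x = torch.randn(4, 4)
    with autocast(device_type='cpu'):
        y = x @ x
    # Under cpu autocast, matmul runs in bfloat16.
    print(f'dtype inside cpu autocast: {y.dtype}')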