| import functools |
| import gc |
|
|
| import torch |
| try: |
| |
| |
| import intel_extension_for_pytorch as ipex |
| except Exception: |
| pass |
|
|
|
|
# Backend capability flags, probed once at import time. Each probe is wrapped
# in try/except because the backend submodule may be absent or broken in some
# torch builds; any failure is treated as "backend not available".
try:
    HAS_CUDA = torch.cuda.is_available()
except Exception:
    HAS_CUDA = False

# Apple Metal Performance Shaders backend (macOS).
try:
    HAS_MPS = torch.backends.mps.is_available()
except Exception:
    HAS_MPS = False

# Intel XPU backend (populated when intel_extension_for_pytorch imported above).
try:
    HAS_XPU = torch.xpu.is_available()
except Exception:
    HAS_XPU = False
|
|
|
|
def clean_memory():
    """Run a garbage-collection pass, then release cached allocator memory on every available accelerator backend."""
    gc.collect()
    # Lazy getattr keeps the behavior of the original: a backend submodule is
    # only touched when its availability flag is set.
    for available, backend_name in ((HAS_CUDA, "cuda"), (HAS_XPU, "xpu"), (HAS_MPS, "mps")):
        if available:
            getattr(torch, backend_name).empty_cache()
|
|
|
|
def clean_memory_on_device(device: torch.device):
    r"""
    Free cached allocator memory for the backend that owns ``device``,
    after a garbage-collection pass. Will be called from training scripts.
    """
    gc.collect()

    # Only these backends expose an ``empty_cache``; anything else (e.g. cpu)
    # needs no cache release.
    backend = device.type
    if backend in ("cuda", "xpu", "mps"):
        getattr(torch, backend).empty_cache()
|
|
|
|
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
    r"""
    Do not call this function from training scripts. Use accelerator.device instead.
    """
    # Pick the first available backend in priority order; fall back to CPU.
    for available, backend in ((HAS_CUDA, "cuda"), (HAS_XPU, "xpu"), (HAS_MPS, "mps")):
        if available:
            device = torch.device(backend)
            break
    else:
        device = torch.device("cpu")
    print(f"get_preferred_device() -> {device}")
    return device
|
|
|
|
def init_ipex():
    """
    Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.

    This function should run right after importing torch and before doing anything else.

    If xpu is not available, this function does nothing.
    """
    # Guard clause: without an XPU there is nothing to hijack.
    if not HAS_XPU:
        return

    # Best-effort initialization: any failure is reported but never raised,
    # so a broken IPEX install does not abort the caller.
    try:
        from library.ipex import ipex_init

        ok, error_message = ipex_init()
        if not ok:
            print("failed to initialize ipex:", error_message)
    except Exception as e:
        print("failed to initialize ipex:", e)
|
|