| import gc |
| import logging |
| from typing import List, TypeVar |
|
|
| import torch |
| from torch.utils.data import Dataset |
|
|
# Module-level logger, named after this module via __name__.
| logger = logging.getLogger(__name__) |
# Generic type variable for the element type held by ListDataset.
| T = TypeVar("T") |
|
|
|
|
def get_torch_device(device: str = "auto") -> str:
    """
    Resolve the PyTorch device string to use.

    An explicit `device` (anything other than "auto") is returned unchanged.
    With the default "auto", the first available option is picked, in order:
    - "cuda:0" when CUDA is available
    - "mps" when the MPS backend is available
    - "cpu" otherwise.

    The auto-detected device is logged at INFO level.
    """
    # Caller named a device explicitly: nothing to detect.
    if device != "auto":
        return device

    # NOTE(review): the original indentation was lost; logging is assumed to
    # belong to the auto-detection path only - confirm against the upstream file.
    if torch.cuda.is_available():
        device = "cuda:0"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    logger.info(f"Using device: {device}")
    return device
|
|
|
|
def tear_down_torch() -> None:
    """
    Teardown for PyTorch.

    Runs the Python garbage collector, then clears the GPU cache of every
    backend (CUDA, MPS) that is available on this machine. Backends that are
    not available are skipped, so this can be called on CPU-only hosts.
    """
    # Collect first so objects kept alive only by reference cycles are released
    # before the allocator caches are emptied.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Only touch the MPS allocator when the backend is usable; calling it
    # unconditionally can fail on builds or hosts without MPS support.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
|
|
|
|
class ListDataset(Dataset[T]):
    """
    Minimal `Dataset` backed by an in-memory list.

    The list passed to the constructor is stored as-is (not copied) in
    `self.elements`; length and item access are delegated to it.
    """

    def __init__(self, elements: List[T]):
        """Keep a reference to `elements` as the dataset contents."""
        self.elements = elements

    def __len__(self) -> int:
        """Return the number of stored elements."""
        stored = self.elements
        return len(stored)

    def __getitem__(self, idx: int) -> T:
        """Return the element at position `idx` of the underlying list."""
        stored = self.elements
        return stored[idx]
|
|