| from typing import Any | |
def detect_accelerator() -> dict[str, Any]:
    """Probe the local PyTorch install and report accelerator details.

    Torch is imported lazily so that calling this function never fails
    merely because torch is missing or broken.

    Returns:
        A dict that always contains these keys:

        - ``torch_available``: whether ``import torch`` succeeded.
        - ``cuda_api_available``: result of ``torch.cuda.is_available()``
          (``False`` when torch could not be imported).
        - ``rocm_hip_version``: ``torch.version.hip`` if present, else ``None``.
        - ``cuda_version``: ``torch.version.cuda`` if present, else ``None``.
        - ``device_name``: name of device 0 when the CUDA API is available,
          else ``None``.
        - ``device_count``: number of visible devices, ``0`` when the CUDA
          API is unavailable or torch could not be imported.

        When the import fails, an additional ``error`` key carries the
        stringified exception.
    """
    try:
        import torch
    except Exception as exc:
        # Broad catch is deliberate: a broken install can raise more than
        # ImportError, and this probe is meant to be best-effort.
        # The failure dict mirrors every key of the success dict so callers
        # can index the result without checking ``torch_available`` first.
        return {
            "torch_available": False,
            "cuda_api_available": False,
            "rocm_hip_version": None,
            "cuda_version": None,
            "device_name": None,
            "device_count": 0,
            "error": str(exc),
        }

    cuda_available = bool(torch.cuda.is_available())
    # Only query the device when the API reports one is usable.
    device_name = torch.cuda.get_device_name(0) if cuda_available else None
    return {
        "torch_available": True,
        "cuda_api_available": cuda_available,
        "rocm_hip_version": getattr(torch.version, "hip", None),
        "cuda_version": getattr(torch.version, "cuda", None),
        "device_name": device_name,
        "device_count": torch.cuda.device_count() if cuda_available else 0,
    }
def torch_device_index() -> int:
    """Return ``0`` when torch reports a usable CUDA API, otherwise ``-1``.

    ``-1`` is returned both when torch cannot be imported and when
    ``torch.cuda.is_available()`` is false.
    """
    try:
        import torch
    except Exception:
        # No importable torch means no accelerator to select.
        return -1

    if torch.cuda.is_available():
        return 0
    return -1