File size: 958 Bytes
dbc3c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import Any


def detect_accelerator() -> dict[str, Any]:
    """Probe PyTorch for an accelerator and summarize what was found.

    The import of ``torch`` happens inside the function, so merely importing
    this module does not require PyTorch to be installed.

    Returns:
        A dict that always contains the keys ``torch_available``,
        ``cuda_api_available``, ``rocm_hip_version``, ``cuda_version``,
        ``device_name`` and ``device_count``. When ``import torch`` fails,
        those fields hold their "nothing found" values (``False`` / ``None``
        / ``0``) and an extra ``error`` key carries the stringified
        exception. The ``error`` key is absent when the import succeeds.
    """
    try:
        import torch
    except Exception as exc:
        # Report the import failure instead of raising. Every key of the
        # success path is present here too, so callers can index the result
        # without first checking ``torch_available``.
        return {
            "torch_available": False,
            "cuda_api_available": False,
            "rocm_hip_version": None,
            "cuda_version": None,
            "device_name": None,
            "device_count": 0,
            "error": str(exc),
        }

    cuda_available = bool(torch.cuda.is_available())
    # Only query device 0 when the CUDA API reports a usable device.
    device_name = torch.cuda.get_device_name(0) if cuda_available else None
    return {
        "torch_available": True,
        "cuda_api_available": cuda_available,
        # getattr with a default: a missing attribute is reported as None.
        "rocm_hip_version": getattr(torch.version, "hip", None),
        "cuda_version": getattr(torch.version, "cuda", None),
        "device_name": device_name,
        "device_count": torch.cuda.device_count() if cuda_available else 0,
    }


def torch_device_index() -> int:
    """Return the device index to use with PyTorch.

    Returns:
        ``0`` when ``torch`` imports and ``torch.cuda.is_available()`` is
        truthy; ``-1`` when the import fails or no CUDA device is available.
        NOTE(review): ``-1`` presumably means "CPU" to the caller; confirm.
    """
    try:
        import torch
    except Exception:
        # A failed import is treated the same as "no accelerator".
        accelerator_present = False
    else:
        accelerator_present = torch.cuda.is_available()

    if accelerator_present:
        return 0
    return -1