Spaces:
Running on Zero
Running on Zero
File size: 5,589 Bytes
6d5047c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 | """Card 9 runtime device bootstrap helpers (AMD/ROCm-friendly)."""
from __future__ import annotations
import logging
import os
from dataclasses import asdict, dataclass
from typing import Optional
import torch
LOGGER = logging.getLogger(__name__)
@dataclass(frozen=True)
class RuntimeHealthReport:
    """Immutable snapshot of device/backend detection for startup health checks."""

    # Normalized device string the caller asked for (e.g. "auto", "cpu", "cuda:0").
    requested_device: str
    # Device string actually chosen after availability checks.
    selected_device: str
    # Canonical backend label: "rocm", "cuda", "mps", or "cpu".
    backend: str
    cuda_available: bool
    rocm_available: bool
    mps_available: bool
    # Whether KIMODO_STRICT_DEVICE was enabled during detection.
    strict_mode: bool
    # Short machine-readable explanation of the selection outcome.
    reason: str

    def to_dict(self) -> dict:
        """Serialize the report as a plain dict (e.g. for a JSON health endpoint)."""
        return asdict(self)
def _env_bool(name: str, default: bool = False) -> bool:
raw = os.environ.get(name)
if raw is None:
return default
return str(raw).strip().lower() in ("1", "true", "yes", "on")
def _normalize_requested_device(requested: Optional[str]) -> str:
value = requested or os.environ.get("KIMODO_DEVICE") or os.environ.get("DEVICE") or "auto"
return str(value).strip().lower()
def _has_mps() -> bool:
backends = getattr(torch, "backends", None)
mps = getattr(backends, "mps", None)
if mps is None:
return False
is_available = getattr(mps, "is_available", None)
if callable(is_available):
try:
return bool(is_available())
except Exception: # pragma: no cover
return False
return False
def _backend_name(cuda_available: bool, rocm_available: bool, mps_available: bool) -> str:
if rocm_available:
return "rocm"
if cuda_available:
return "cuda"
if mps_available:
return "mps"
return "cpu"
def select_runtime_device(requested: Optional[str] = None) -> str:
    """Resolve the runtime device string with ROCm/CUDA/MPS/CPU fallback.

    Resolution order for the request itself:
      1. explicit ``requested`` argument
      2. environment variable ``KIMODO_DEVICE`` (or ``DEVICE``)
      3. ``"auto"``

    Args:
        requested: Optional device specifier such as "auto", "cpu", "cuda",
            "cuda:1", "rocm", "mps". Case-insensitive.

    Returns:
        A torch device string: "cuda:0"/"cuda:N", "mps", or "cpu".

    Raises:
        ValueError: if KIMODO_STRICT_DEVICE is truthy and the requested
            accelerator is unavailable, or the specifier is unknown.
    """
    LOGGER.info("card9.select_runtime_device.start requested=%s", requested)
    strict_mode = _env_bool("KIMODO_STRICT_DEVICE", default=False)
    req = _normalize_requested_device(requested)

    # ROCm builds of torch expose the GPU through the CUDA API, so
    # torch.cuda.is_available() covers both NVIDIA and AMD accelerators;
    # no separate ROCm probe is needed here.
    cuda_available = bool(torch.cuda.is_available())
    mps_available = _has_mps()

    # Aliases that all mean "give me the primary accelerator".
    accelerator_aliases = {"cuda", "cuda:0", "gpu", "rocm", "hip", "amd"}

    if req == "cpu":
        selected = "cpu"
        reason = "explicit_cpu"
    elif req in ("mps", "apple"):
        if mps_available:
            selected = "mps"
            reason = "explicit_mps"
        elif strict_mode:
            raise ValueError("Requested MPS device but MPS backend is unavailable")
        else:
            selected = "cpu"
            reason = "mps_unavailable_fallback_cpu"
    elif req in accelerator_aliases:
        if cuda_available:
            selected = "cuda:0"
            reason = "explicit_accelerator_available"
        elif strict_mode:
            raise ValueError(f"Requested accelerator '{req}' but no torch accelerator is available")
        else:
            selected = "cpu"
            reason = "accelerator_unavailable_fallback_cpu"
    elif req == "auto":
        if cuda_available:
            selected = "cuda:0"
            reason = "auto_accelerator"
        elif mps_available:
            selected = "mps"
            reason = "auto_mps"
        else:
            selected = "cpu"
            reason = "auto_cpu"
    else:
        # Preserve explicit torch device strings (e.g. cuda:1, cpu) when possible.
        if req.startswith("cuda"):
            if cuda_available:
                selected = req
                reason = "explicit_cuda_index"
            elif strict_mode:
                raise ValueError(f"Requested device '{req}' but CUDA/ROCm backend is unavailable")
            else:
                selected = "cpu"
                reason = "explicit_cuda_unavailable_fallback_cpu"
        else:
            if strict_mode:
                raise ValueError(f"Unknown device specifier '{req}'")
            selected = "cpu"
            reason = "unknown_device_fallback_cpu"

    LOGGER.info("card9.select_runtime_device.exit selected=%s reason=%s", selected, reason)
    return selected
def runtime_health_report(requested: Optional[str] = None) -> RuntimeHealthReport:
    """Build a startup runtime report suitable for health checks and logs."""
    LOGGER.info("card9.runtime_health_report.start requested=%s", requested)

    strict_mode = _env_bool("KIMODO_STRICT_DEVICE", default=False)
    req = _normalize_requested_device(requested)

    cuda_available = bool(torch.cuda.is_available())
    rocm_available = cuda_available and bool(getattr(torch.version, "hip", None))
    mps_available = _has_mps()

    selected = select_runtime_device(req)

    # Flag the report when an accelerator was requested but we landed on CPU.
    accelerator_requested = req in {"cuda", "cuda:0", "gpu", "rocm", "hip", "amd"}
    reason = "fallback_cpu" if selected == "cpu" and accelerator_requested else "ok"

    report = RuntimeHealthReport(
        requested_device=req,
        selected_device=selected,
        backend=_backend_name(cuda_available, rocm_available, mps_available),
        cuda_available=cuda_available,
        rocm_available=rocm_available,
        mps_available=mps_available,
        strict_mode=strict_mode,
        reason=reason,
    )
    LOGGER.info(
        "card9.runtime_health_report.exit requested=%s selected=%s backend=%s reason=%s",
        report.requested_device,
        report.selected_device,
        report.backend,
        report.reason,
    )
    return report