from __future__ import annotations import importlib.util import shutil from dataclasses import dataclass from typing import Dict import torch @dataclass(frozen=True) class XQSBackendReport: torch_version: str cuda_available: bool cuda_device_name: str bf16_supported: bool torch_compile_available: bool triton_available: bool deepspeed_available: bool bitsandbytes_available: bool flash_attn_available: bool nvcc_available: bool def as_dict(self) -> Dict[str, object]: return { "torch_version": self.torch_version, "cuda_available": self.cuda_available, "cuda_device_name": self.cuda_device_name, "bf16_supported": self.bf16_supported, "torch_compile_available": self.torch_compile_available, "triton_available": self.triton_available, "deepspeed_available": self.deepspeed_available, "bitsandbytes_available": self.bitsandbytes_available, "flash_attn_available": self.flash_attn_available, "nvcc_available": self.nvcc_available, } def _has_module(name: str) -> bool: return importlib.util.find_spec(name) is not None def detect_xqs_backends() -> XQSBackendReport: cuda_available = torch.cuda.is_available() device_name = torch.cuda.get_device_name(0) if cuda_available else "cpu" bf16_supported = bool(cuda_available and torch.cuda.is_bf16_supported()) return XQSBackendReport( torch_version=torch.__version__, cuda_available=cuda_available, cuda_device_name=device_name, bf16_supported=bf16_supported, torch_compile_available=hasattr(torch, "compile"), triton_available=_has_module("triton"), deepspeed_available=_has_module("deepspeed"), bitsandbytes_available=_has_module("bitsandbytes"), flash_attn_available=_has_module("flash_attn"), nvcc_available=shutil.which("nvcc") is not None, ) def choose_attention_backend(prefer_flash: bool = True) -> str: report = detect_xqs_backends() if prefer_flash and report.flash_attn_available and report.cuda_available: return "flash_attn" if report.cuda_available: return "scaled_dot_product_attention" return "eager" def choose_optimizer_backend(prefer_low_memory: bool = True) -> str: report = detect_xqs_backends() adamw_signature = getattr(torch.optim.AdamW, "__init__", None) fused_supported = bool(adamw_signature and "fused" in adamw_signature.__code__.co_varnames) if report.cuda_available and fused_supported: return "adamw_fused" if prefer_low_memory and report.bitsandbytes_available: return "adam8bit" if _has_module("transformers"): return "adafactor" return "sgd" def choose_moe_backend(prefer_deepspeed: bool = True) -> str: report = detect_xqs_backends() if prefer_deepspeed and report.deepspeed_available and report.cuda_available: return "deepspeed" return "native" def choose_quant_backend(prefer_triton: bool = True) -> str: report = detect_xqs_backends() if prefer_triton and report.triton_available and report.cuda_available: return "triton" return "pytorch" def format_backend_report(report: XQSBackendReport) -> str: ordered = report.as_dict() return "\n".join(f"{key}={value}" for key, value in ordered.items())