from __future__ import annotations

import importlib.util
import inspect
import shutil
from dataclasses import asdict, dataclass
from typing import Dict

import torch
|
|
|
|
@dataclass(frozen=True)
class XQSBackendReport:
    """Immutable snapshot of the ML backend capabilities of this environment.

    Produced by ``detect_xqs_backends``; all fields are plain strings/bools
    so the report is cheap to serialize and log.
    """

    torch_version: str  # value of torch.__version__
    cuda_available: bool
    cuda_device_name: str  # name of CUDA device 0, or "cpu" when CUDA is absent
    bf16_supported: bool
    torch_compile_available: bool  # torch.compile exists (torch >= 2.0)
    triton_available: bool
    deepspeed_available: bool
    bitsandbytes_available: bool
    flash_attn_available: bool
    nvcc_available: bool  # the `nvcc` binary is on PATH

    def as_dict(self) -> Dict[str, object]:
        """Return the report as a plain dict keyed by field name.

        Uses dataclasses.asdict so the mapping stays in sync with the field
        list automatically; the previous hand-written dict literal repeated
        every field name and would silently drift when fields change.
        """
        return asdict(self)
|
|
|
|
| def _has_module(name: str) -> bool: |
| return importlib.util.find_spec(name) is not None |
|
|
|
|
|
|
def detect_xqs_backends() -> XQSBackendReport:
    """Probe the current environment and return a capability report.

    Purely read-only: queries torch, checks importability of optional
    packages, and looks for `nvcc` on PATH.
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        gpu_name = torch.cuda.get_device_name(0)
    else:
        gpu_name = "cpu"
    supports_bf16 = bool(has_cuda and torch.cuda.is_bf16_supported())
    return XQSBackendReport(
        torch_version=torch.__version__,
        cuda_available=has_cuda,
        cuda_device_name=gpu_name,
        bf16_supported=supports_bf16,
        torch_compile_available=hasattr(torch, "compile"),
        triton_available=_has_module("triton"),
        deepspeed_available=_has_module("deepspeed"),
        bitsandbytes_available=_has_module("bitsandbytes"),
        flash_attn_available=_has_module("flash_attn"),
        nvcc_available=shutil.which("nvcc") is not None,
    )
|
|
|
|
|
|
def choose_attention_backend(prefer_flash: bool = True) -> str:
    """Pick an attention implementation for this machine.

    Returns "flash_attn" when flash-attention is installed on a CUDA box
    (unless *prefer_flash* is False), "scaled_dot_product_attention" on any
    other CUDA setup, and "eager" on CPU-only hosts.
    """
    caps = detect_xqs_backends()
    if not caps.cuda_available:
        return "eager"
    if prefer_flash and caps.flash_attn_available:
        return "flash_attn"
    return "scaled_dot_product_attention"
|
|
|
|
|
|
def choose_optimizer_backend(prefer_low_memory: bool = True) -> str:
    """Select an optimizer implementation based on available backends.

    Preference order: fused AdamW on CUDA, bitsandbytes 8-bit Adam when
    *prefer_low_memory* is set, Adafactor when transformers is installed,
    plain SGD as the last resort.
    """
    report = detect_xqs_backends()
    # Inspect the real constructor signature instead of __code__.co_varnames:
    # co_varnames also lists local variables (risking false positives), and
    # __code__ does not exist at all when __init__ is implemented in C, which
    # would raise AttributeError here.
    try:
        fused_supported = "fused" in inspect.signature(torch.optim.AdamW).parameters
    except (TypeError, ValueError):
        # signature() raises for objects it cannot introspect; treat as
        # "fused not supported" rather than crashing backend selection.
        fused_supported = False
    if report.cuda_available and fused_supported:
        return "adamw_fused"
    if prefer_low_memory and report.bitsandbytes_available:
        return "adam8bit"
    if _has_module("transformers"):
        return "adafactor"
    return "sgd"
|
|
|
|
|
|
def choose_moe_backend(prefer_deepspeed: bool = True) -> str:
    """Choose the mixture-of-experts backend.

    "deepspeed" when DeepSpeed is installed on a CUDA host and
    *prefer_deepspeed* is set; otherwise the native implementation.
    """
    caps = detect_xqs_backends()
    use_deepspeed = prefer_deepspeed and caps.deepspeed_available and caps.cuda_available
    return "deepspeed" if use_deepspeed else "native"
|
|
|
|
|
|
def choose_quant_backend(prefer_triton: bool = True) -> str:
    """Choose the quantization kernel backend.

    "triton" when Triton is installed on a CUDA host and *prefer_triton*
    is set; otherwise the plain PyTorch fallback.
    """
    caps = detect_xqs_backends()
    use_triton = prefer_triton and caps.cuda_available and caps.triton_available
    return "triton" if use_triton else "pytorch"
|
|
|
|
|
|
def format_backend_report(report: XQSBackendReport) -> str:
    """Render *report* as one "key=value" pair per line, in field order."""
    pairs = [f"{field}={value}" for field, value in report.as_dict().items()]
    return "\n".join(pairs)
|
|