# RubiRLM-1B-Base / xqs_stack.py
# Uploaded via huggingface_hub (commit cd16f07, verified) by DevHunterAI.
from __future__ import annotations

import importlib.util
import inspect
import shutil
from dataclasses import dataclass
from typing import Dict

import torch
@dataclass(frozen=True)
class XQSBackendReport:
    """Immutable snapshot of which acceleration backends are available.

    Each flag records whether one optional dependency / hardware feature
    was detected at probe time; the field order is the report order.
    """

    torch_version: str
    cuda_available: bool
    cuda_device_name: str
    bf16_supported: bool
    torch_compile_available: bool
    triton_available: bool
    deepspeed_available: bool
    bitsandbytes_available: bool
    flash_attn_available: bool
    nvcc_available: bool

    def as_dict(self) -> Dict[str, object]:
        """Return the report as a plain ``{field_name: value}`` mapping.

        Keys appear in declaration order so the rendered report is stable.
        """
        field_order = (
            "torch_version",
            "cuda_available",
            "cuda_device_name",
            "bf16_supported",
            "torch_compile_available",
            "triton_available",
            "deepspeed_available",
            "bitsandbytes_available",
            "flash_attn_available",
            "nvcc_available",
        )
        return {name: getattr(self, name) for name in field_order}
def _has_module(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def detect_xqs_backends() -> XQSBackendReport:
    """Probe the runtime and return an :class:`XQSBackendReport`.

    CUDA-dependent fields (device name, bf16) fall back to CPU defaults
    when no CUDA device is available.
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        gpu_name = torch.cuda.get_device_name(0)
        bf16_ok = bool(torch.cuda.is_bf16_supported())
    else:
        gpu_name = "cpu"
        bf16_ok = False
    return XQSBackendReport(
        torch_version=torch.__version__,
        cuda_available=has_cuda,
        cuda_device_name=gpu_name,
        bf16_supported=bf16_ok,
        torch_compile_available=hasattr(torch, "compile"),
        triton_available=_has_module("triton"),
        deepspeed_available=_has_module("deepspeed"),
        bitsandbytes_available=_has_module("bitsandbytes"),
        flash_attn_available=_has_module("flash_attn"),
        nvcc_available=shutil.which("nvcc") is not None,
    )
def choose_attention_backend(prefer_flash: bool = True) -> str:
    """Pick an attention implementation name for the current machine.

    Returns "flash_attn" when preferred and available on CUDA,
    "scaled_dot_product_attention" on any other CUDA setup, else "eager".
    """
    caps = detect_xqs_backends()
    if not caps.cuda_available:
        # Neither flash-attn nor SDPA fast paths apply without a GPU.
        return "eager"
    use_flash = prefer_flash and caps.flash_attn_available
    return "flash_attn" if use_flash else "scaled_dot_product_attention"
def choose_optimizer_backend(prefer_low_memory: bool = True) -> str:
    """Pick an optimizer backend name based on the detected environment.

    Preference order: fused AdamW on CUDA, bitsandbytes 8-bit Adam when
    ``prefer_low_memory`` is set, Adafactor when ``transformers`` is
    installed, and plain SGD as the last resort.

    Args:
        prefer_low_memory: Favor the 8-bit optimizer when bitsandbytes
            is available and fused AdamW is not usable.

    Returns:
        One of "adamw_fused", "adam8bit", "adafactor", or "sgd".
    """
    report = detect_xqs_backends()
    # Bug fix: the previous check peeked at ``__init__.__code__.co_varnames``,
    # which also lists *local* variables (false positives) and raises
    # AttributeError on C-implemented or wrapped callables. Inspecting the
    # signature's declared parameters is the reliable way to detect support
    # for the ``fused`` keyword.
    try:
        adamw_params = inspect.signature(torch.optim.AdamW.__init__).parameters
        fused_supported = "fused" in adamw_params
    except (TypeError, ValueError):
        # Signature not introspectable (e.g. builtin); assume no fused kwarg.
        fused_supported = False
    if report.cuda_available and fused_supported:
        return "adamw_fused"
    if prefer_low_memory and report.bitsandbytes_available:
        return "adam8bit"
    if _has_module("transformers"):
        return "adafactor"
    return "sgd"
def choose_moe_backend(prefer_deepspeed: bool = True) -> str:
    """Return "deepspeed" when preferred and usable on CUDA, else "native"."""
    caps = detect_xqs_backends()
    use_ds = prefer_deepspeed and caps.deepspeed_available and caps.cuda_available
    return "deepspeed" if use_ds else "native"
def choose_quant_backend(prefer_triton: bool = True) -> str:
    """Return "triton" when preferred and usable on CUDA, else "pytorch"."""
    caps = detect_xqs_backends()
    use_triton = prefer_triton and caps.triton_available and caps.cuda_available
    return "triton" if use_triton else "pytorch"
def format_backend_report(report: XQSBackendReport) -> str:
    """Render *report* as one ``key=value`` pair per line, in field order."""
    pairs = []
    for name, value in report.as_dict().items():
        pairs.append(f"{name}={value}")
    return "\n".join(pairs)