from __future__ import annotations import importlib.util from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from xqs_stack import choose_attention_backend, choose_quant_backend from xqs_triton_ops import triton_ternary_linear _HAS_FLASH_ATTN = importlib.util.find_spec("flash_attn") is not None if _HAS_FLASH_ATTN: from flash_attn import flash_attn_func _ATTN_BACKEND = choose_attention_backend(prefer_flash=True) _QUANT_BACKEND = choose_quant_backend(prefer_triton=True) def ternary_quantize(weight: torch.Tensor) -> torch.Tensor: scale = weight.detach().abs().mean().clamp(min=1e-6) pos = weight > (0.5 * scale) neg = weight < (-0.5 * scale) quantized = torch.zeros_like(weight) quantized = torch.where(pos, torch.ones_like(weight), quantized) quantized = torch.where(neg, -torch.ones_like(weight), quantized) quantized = quantized * scale return weight + (quantized - weight).detach() class TernaryLinear(nn.Module): def __init__(self, in_features: int, out_features: int, bias: bool = True): super().__init__() self.in_features = in_features self.out_features = out_features self.backend = _QUANT_BACKEND self.weight = nn.Parameter(torch.empty(out_features, in_features)) if bias: self.bias = nn.Parameter(torch.empty(out_features)) else: self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5) if self.bias is not None: bound = 1 / max(1, self.in_features) ** 0.5 nn.init.uniform_(self.bias, -bound, bound) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.backend == "triton": return triton_ternary_linear(x, self.weight, self.bias) return F.linear(x, ternary_quantize(self.weight), self.bias) def build_linear(in_features: int, out_features: int, bias: bool = True, ternary: bool = False) -> nn.Module: if ternary: return TernaryLinear(in_features, out_features, bias=bias) return nn.Linear(in_features, out_features, bias=bias) def fused_residual_add(x: torch.Tensor, residual: torch.Tensor, gate: Optional[torch.Tensor] = None) -> torch.Tensor: if gate is None: return x + residual return x + (gate * residual) def causal_scaled_dot_product_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float = 0.0, training: bool = False, ) -> torch.Tensor: if _ATTN_BACKEND == "flash_attn" and _HAS_FLASH_ATTN and q.is_cuda and q.dtype in {torch.float16, torch.bfloat16}: q_flash = q.transpose(1, 2).contiguous() k_flash = k.transpose(1, 2).contiguous() v_flash = v.transpose(1, 2).contiguous() out = flash_attn_func( q_flash, k_flash, v_flash, dropout_p=dropout_p if training else 0.0, causal=True, ) return out.transpose(1, 2).contiguous() if hasattr(F, "scaled_dot_product_attention"): return F.scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=dropout_p if training else 0.0, is_causal=True, ) scale = q.size(-1) ** -0.5 scores = torch.matmul(q, k.transpose(-2, -1)) * scale causal_mask = torch.triu(torch.ones(scores.size(-2), scores.size(-1), device=scores.device, dtype=torch.bool), diagonal=1) scores = scores.masked_fill(causal_mask, float("-inf")) probs = torch.softmax(scores, dim=-1) if training and dropout_p > 0: probs = F.dropout(probs, p=dropout_p) return torch.matmul(probs, v) def pack_rows(indices: torch.Tensor, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]: return tuple(t.index_select(0, indices) for t in tensors) def scatter_rows(base: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: if indices.numel() == 0: return base out = base.clone() out.index_copy_(0, indices, updates) return out def maybe_compile_module(module: nn.Module, enabled: bool) -> nn.Module: if not enabled: return module compile_fn = getattr(torch, "compile", None) if compile_fn is None: return module try: return compile_fn(module) except Exception: return module