""" Chimera 5.2 — 1.58-bit Ternary Compute (CPU-First, Slim) ======================================================== Single, clean implementation of BitNet-1.58 ternary linear layers. Design goals: * Zero overhead at import time (no JIT, no kernel discovery). * One fast pure-PyTorch path that vectorises everything; an optional C++/OpenMP path that is loaded *lazily* and only used when it actually beats PyTorch (small batches on inference). * Cache the packed 2-bit weights between forward calls and only repack when the latent FP32 weights are mutated (training step or MeZO). * No data-dependent Python loops, no per-row mask construction at init. * torch.compile compatible: STE uses detach() trick (zero graph breaks). Storage: weight: FP32 latent of shape [M, K] (kept for STE backward / MeZO updates) _packed: uint8 [M, ceil(K/4)] (2 bits per ternary value) _alpha: fp32 [M] (per-row absolute mean scale) Encoding (matches the C++ kernel): -1 → 0b10 0 → 0b00 +1 → 0b01 """ from __future__ import annotations import math import os import threading from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F # --------------------------------------------------------------------------- # Lazy C++ kernel. # --------------------------------------------------------------------------- _NATIVE_LOCK = threading.Lock() _NATIVE_EXT: Optional[object] = None _NATIVE_TRIED = False _CPP_SOURCE = r""" #include #include #include #ifdef _OPENMP #include #endif static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f}; torch::Tensor pack_ternary_cpu(torch::Tensor w) { TORCH_CHECK(w.dim() == 2 && w.dtype() == torch::kInt8, "expected int8 [M,K]"); auto w_c = w.contiguous(); int64_t M = w_c.size(0), K = w_c.size(1); int64_t K4 = (K + 3) / 4; auto out = torch::zeros({M, K4}, torch::kUInt8); const int8_t* s = w_c.data_ptr(); uint8_t* d = out.data_ptr(); #pragma omp parallel for schedule(static) for (int64_t m = 0; m < M; ++m) { const int8_t* sr = s + m * K; uint8_t* dr = d + m * K4; for (int64_t k4 = 0; k4 < K4; ++k4) { uint8_t b = 0; for (int j = 0; j < 4; ++j) { int64_t k = k4 * 4 + j; if (k >= K) break; int8_t v = sr[k]; uint8_t code = (v == 1) ? 1u : (v == -1 ? 
def _try_load_native() -> Optional[object]:
    global _NATIVE_EXT, _NATIVE_TRIED
    if _NATIVE_TRIED:
        return _NATIVE_EXT
    with _NATIVE_LOCK:
        if _NATIVE_TRIED:
            return _NATIVE_EXT
        _NATIVE_TRIED = True
        try:
            from torch.utils.cpp_extension import load_inline

            build_dir = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build"
            )
            os.makedirs(build_dir, exist_ok=True)
            _NATIVE_EXT = load_inline(
                name="chimera_ternary",
                cpp_sources=_CPP_SOURCE,
                extra_cflags=["-O3", "-fopenmp", "-ffast-math", "-funroll-loops"],
                extra_ldflags=["-lgomp"],
                build_directory=build_dir,
                verbose=False,
            )
        except Exception as exc:
            os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200])
            _NATIVE_EXT = None
    return _NATIVE_EXT


def enable_native_kernel(force: bool = False) -> bool:
    global _NATIVE_TRIED
    if force:
        _NATIVE_TRIED = False
    return _try_load_native() is not None


def native_kernel_available() -> bool:
    return _NATIVE_EXT is not None


if os.environ.get("CHIMERA_NATIVE", "0") == "1":
    enable_native_kernel()
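

# Illustrative sketch (the helper below is not part of the module API): how a
# caller would opt into the lazily built C++/OpenMP path at runtime.  The same
# effect can be had by exporting CHIMERA_NATIVE=1 before import, as handled above.
def _example_enable_native() -> None:
    if enable_native_kernel():
        assert native_kernel_available()
    else:
        # Build failed (no compiler / no OpenMP): the pure-PyTorch path is used
        # transparently and the reason is stashed in CHIMERA_NATIVE_DISABLED.
        print(os.environ.get("CHIMERA_NATIVE_DISABLED", "native kernel unavailable"))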
# ---------------------------------------------------------------------------
# Pure PyTorch ternary primitives.
# ---------------------------------------------------------------------------

_TERNARY_LUT_F32 = torch.tensor([0.0, 1.0, -1.0, 0.0], dtype=torch.float32)
_TERNARY_LUT_I8 = torch.tensor([0, 1, -1, 0], dtype=torch.int8)
_SHIFTS = torch.tensor([6, 4, 2, 0], dtype=torch.uint8)


def pack_ternary(q: torch.Tensor) -> torch.Tensor:
    q = q.detach()
    if q.dim() == 1:
        q = q.unsqueeze(0)
    flat = q.reshape(-1, q.shape[-1]).to(torch.int8)
    M, K = flat.shape
    K4 = (K + 3) // 4
    pad = K4 * 4 - K
    if pad:
        flat = F.pad(flat, (0, pad))
    codes = torch.where(flat == 1, torch.full_like(flat, 1),
                        torch.where(flat == -1, torch.full_like(flat, 2),
                                    torch.zeros_like(flat))).to(torch.uint8)
    codes = codes.view(M, K4, 4)
    packed = ((codes[..., 0] << 6) | (codes[..., 1] << 4) |
              (codes[..., 2] << 2) | codes[..., 3]).contiguous()
    return packed.reshape(*q.shape[:-1], K4)


def unpack_ternary(packed: torch.Tensor, k: int,
                   alpha: Optional[torch.Tensor] = None,
                   dtype: torch.dtype = torch.float32) -> torch.Tensor:
    packed = packed.to(torch.uint8)
    if packed.dim() == 1:
        packed = packed.unsqueeze(0)
    flat = packed.reshape(-1, packed.shape[-1])
    M, K4 = flat.shape
    shifts = _SHIFTS.to(packed.device)
    codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long)
    lut = _TERNARY_LUT_F32.to(device=packed.device, dtype=dtype)
    out = lut[codes].reshape(M, K4 * 4)[:, :k]
    if alpha is not None:
        out = out * alpha.reshape(M, 1).to(device=out.device, dtype=out.dtype)
    return out.reshape(*packed.shape[:-1], k)


def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    return weight.detach().abs().mean(dim=-1, keepdim=False).clamp_min(eps).to(torch.float32)


def ternarize_weight(weight: torch.Tensor, group_size: int = 128
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
    alpha = _absmean_alpha(weight)
    w_q = torch.round(torch.clamp(weight / alpha.unsqueeze(-1), -1.0, 1.0)).to(torch.int8)
    return w_q, alpha


_quantize_weights_ternary = ternarize_weight


def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        last = weight.shape[-1]
        pad = (-last) % 4
        target = F.pad(weight, (0, pad)) if pad else weight
        view = target.view(*target.shape[:-1], -1, 4)
        idx = view.abs().argsort(dim=-1)[..., :2]
        view.scatter_(-1, idx, 0.0)
        if pad:
            weight.copy_(target[..., :last])
    return weight
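

# Illustrative round trip (helper name is ours, not part of the API): absmean
# ternarisation followed by pack/unpack.  K=5 exercises the padding path
# (ceil(5/4) = 2 bytes per row); the values are arbitrary.
def _example_pack_roundtrip() -> None:
    w = torch.randn(3, 5)
    w_q, alpha = ternarize_weight(w)                 # int8 {-1,0,1} [3,5], fp32 [3]
    packed = pack_ternary(w_q)                       # uint8 [3, 2]
    restored = unpack_ternary(packed, 5)             # fp32 {-1,0,1} [3, 5]
    assert torch.equal(restored, w_q.to(torch.float32))
    dense = unpack_ternary(packed, 5, alpha=alpha)   # scaled: w_q * alpha per row
    assert torch.allclose(dense, w_q.float() * alpha.unsqueeze(-1))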
# ---------------------------------------------------------------------------
# Straight-Through Estimator for ternary quantization.
# ---------------------------------------------------------------------------
#
# CLAMP-AWARE STE using the detach() trick:
#
#     clamped = clamp(w, -1, 1)
#     w_q = clamped + (round(clamped) - clamped).detach()
#
# Forward:  evaluates to round(clamp(w, -1, 1)) — same as before.
# Backward: ∂/∂w [clamp(w, -1, 1)] = 1 if |w| <= 1 else 0.
#   → Gradients are ZERO for weights outside [-1, 1] (at the quantization boundary).
#   → Gradients pass through unchanged inside [-1, 1] (STE identity).
#
# This prevents the gradient explosion that caused NaNs at step ~150 with the
# pure identity STE (w + (quant - w).detach()).  The clamp derivative acts
# as a natural gradient gate: weights that have drifted beyond the ternary
# range get no gradient push, preventing runaway accumulation.
#
# Ref: 4-bit CPU training (arxiv:2603.13931) uses tanh soft clipping for
# the same stabilization purpose.
# ---------------------------------------------------------------------------


class _RoundTernarySTE(torch.autograd.Function):
    """LEGACY — kept for backward compat. Use ste_ternary() instead."""

    @staticmethod
    def forward(ctx, w: torch.Tensor) -> torch.Tensor:
        return torch.round(torch.clamp(w, -1.0, 1.0))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output.clamp(-1.0, 1.0)


def ste_ternary(w: torch.Tensor) -> torch.Tensor:
    """Straight-through estimator for ternary quantization.

    Forward:  round(clamp(w, -1, 1))
    Backward: clamp derivative (zero outside [-1, 1], identity inside)

    Uses the detach() trick for zero graph breaks under torch.compile.
    """
    clamped = torch.clamp(w, -1.0, 1.0)
    w_q = torch.round(clamped)
    return clamped + (w_q - clamped).detach()
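

# Illustrative check of the gating described above (the tensor values are ours,
# not from the source): the STE passes gradients through for entries inside
# [-1, 1] and blocks them for entries that have drifted outside.
def _example_ste_gradient() -> None:
    w = torch.tensor([-2.0, -0.4, 0.3, 1.7], requires_grad=True)
    ste_ternary(w).sum().backward()
    # d/dw round(clamp(w, -1, 1)) under the STE: 1 inside [-1, 1], 0 outside.
    assert torch.equal(w.grad, torch.tensor([0.0, 1.0, 1.0, 0.0]))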
# ---------------------------------------------------------------------------
# BitLinear
# ---------------------------------------------------------------------------


class BitLinear(nn.Module):
    """Linear layer with ternary {-1, 0, 1} weights and per-row absmean scale.

    *Training*:  STE ternarisation with clamp-aware gradient gating.
    *Inference*: cached packed 2-bit uint8 weights.
    """

    __constants__ = ["in_features", "out_features", "use_2_4"]

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 group_size: int = 128, nm_2_4: bool = False):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.group_size = int(group_size)
        self.use_2_4 = bool(nm_2_4)
        self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(self.out_features))
        else:
            self.register_parameter("bias", None)
        self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
        self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
        self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
        self._packed_version = -1
        self._dense_version = -1
        self._cache_version = 0
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        self._cache_version += 1

    def invalidate_packed(self) -> None:
        self._cache_version += 1
        if self._dense_w.numel() > 0:
            self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
        self._dense_version = -1

    def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            w = self.weight
            alpha = _absmean_alpha(w)
            w_q = torch.round(torch.clamp(w / alpha.unsqueeze(-1), -1.0, 1.0))
            if self.use_2_4:
                apply_2_4_sparsity_(w_q)
            return w_q.to(torch.int8), alpha

    def _ensure_packed(self) -> None:
        if self._packed_version == self._cache_version and self._packed.numel() > 0:
            return
        with torch.no_grad():
            w_q, alpha = self._quantize_latent()
            ext = _NATIVE_EXT
            if ext is not None:
                packed = ext.pack_ternary(w_q)
            else:
                packed = pack_ternary(w_q)
            self._packed = packed.contiguous()
            self._alpha = alpha.contiguous()
            self._packed_version = self._cache_version

    @torch.no_grad()
    def prepare_for_inference(self) -> None:
        self.invalidate_packed()
        self._ensure_packed()

    @torch.no_grad()
    def ternary_nonzero_mask(self) -> torch.Tensor:
        self._ensure_packed()
        ext = _NATIVE_EXT
        if ext is not None:
            w = ext.unpack_ternary(self._packed, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features)
        return w.ne(0)

    def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
        """STE forward with clamp-aware gradient gating.

        The clamp on w_scaled ensures:
        - Forward:  round(clamp(w/alpha, -1, 1)) * alpha — correct ternary
        - Backward: gradient is ZERO for w_scaled outside [-1, 1], preventing
          gradient explosion from weights at the boundary.
        """
        w = self.weight
        alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
        w_scaled = w / alpha
        # Clamp FIRST, then detach the rounding residual.
        # Gradient of clamp: 1 inside [-1, 1], 0 outside → natural gradient gate.
        clamped = torch.clamp(w_scaled, -1.0, 1.0)
        w_q = clamped + (torch.round(clamped) - clamped).detach()
        w_q = w_q * alpha
        if self.use_2_4:
            with torch.no_grad():
                mask = (apply_2_4_sparsity_(w_q.detach().clone()) != 0).to(w_q.dtype)
            w_q = w_q * mask
        return F.linear(x, w_q.to(x.dtype), self.bias)

    def _ensure_dense(self) -> torch.Tensor:
        self._ensure_packed()
        if self._dense_version == self._cache_version and self._dense_w.numel() > 0:
            return self._dense_w
        ext = _NATIVE_EXT
        if ext is not None:
            w = ext.dequantize(self._packed, self._alpha, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1)
        self._dense_w = w.contiguous()
        self._dense_version = self._cache_version
        return self._dense_w

    def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
        w = self._ensure_dense()
        if x.dtype != w.dtype:
            w_used = w.to(x.dtype)
        else:
            w_used = w
        return F.linear(x, w_used, self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and torch.is_grad_enabled():
            return self._forward_train(x)
        return self._forward_packed(x)

    def extra_repr(self) -> str:
        return (f"in_features={self.in_features}, out_features={self.out_features}, "
                f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, "
                f"native={native_kernel_available()}")


# ---------------------------------------------------------------------------
# RMSNorm
# ---------------------------------------------------------------------------


class RMSNorm(nn.Module):
    __constants__ = ["dim", "eps"]

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = int(dim)
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        if dtype != torch.float32:
            x32 = x.float()
            rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
            return (x32 * rms).to(dtype) * self.weight
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
        return x * rms * self.weight


__all__ = [
    "BitLinear",
    "RMSNorm",
    "ste_ternary",
    "pack_ternary",
    "unpack_ternary",
    "ternarize_weight",
    "_quantize_weights_ternary",
    "apply_2_4_sparsity_",
    "enable_native_kernel",
    "native_kernel_available",
]
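

# Minimal smoke test (illustrative sizes; the block below is not part of the
# library surface): train-mode forward goes through the STE path,
# prepare_for_inference() packs the weights to 2 bits, and eval-mode forward
# reuses the cached dense matrix.
if __name__ == "__main__":
    layer = BitLinear(64, 32, bias=True)
    x = torch.randn(4, 64)

    layer.train()
    layer(x).sum().backward()           # STE path; grads reach the FP32 latent
    assert layer.weight.grad is not None

    layer.eval()
    layer.prepare_for_inference()       # quantize, pack to 2-bit, cache
    with torch.no_grad():
        y = layer(x)                    # cached packed/dense path
    print("ok", tuple(y.shape), "native:", native_kernel_available())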