""" Chimera 5.2 — 1.58-bit Ternary Compute (CPU-First, Slim) ======================================================== Single, clean implementation of BitNet-1.58 ternary linear layers. Design goals: * Zero overhead at import time (no JIT, no kernel discovery). * One fast pure-PyTorch path that vectorises everything; an optional C++/OpenMP path that is loaded *lazily* and only used when it actually beats PyTorch (small batches on inference). * Cache the packed 2-bit weights between forward calls and only repack when the latent FP32 weights are mutated (training step or MeZO). * No data-dependent Python loops, no per-row mask construction at init. Storage: weight: FP32 latent of shape [M, K] (kept for STE backward / MeZO updates) _packed: uint8 [M, ceil(K/4)] (2 bits per ternary value) _alpha: fp32 [M] (per-row absolute mean scale) Encoding (matches the C++ kernel): -1 → 0b10 0 → 0b00 +1 → 0b01 """ from __future__ import annotations import math import os import threading from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F # --------------------------------------------------------------------------- # Lazy C++ kernel. We never compile it during ``import``; it is only built # when explicitly requested via :func:`enable_native_kernel` or the env var # ``CHIMERA_NATIVE=1``. All public APIs work with the pure-PyTorch path. # --------------------------------------------------------------------------- _NATIVE_LOCK = threading.Lock() _NATIVE_EXT: Optional[object] = None _NATIVE_TRIED = False _CPP_SOURCE = r""" #include #include #include #ifdef _OPENMP #include #endif // Encoding: -1->0b10, 0->0b00, +1->0b01 static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f}; torch::Tensor pack_ternary_cpu(torch::Tensor w) { TORCH_CHECK(w.dim() == 2 && w.dtype() == torch::kInt8, "expected int8 [M,K]"); auto w_c = w.contiguous(); int64_t M = w_c.size(0), K = w_c.size(1); int64_t K4 = (K + 3) / 4; auto out = torch::zeros({M, K4}, torch::kUInt8); const int8_t* s = w_c.data_ptr(); uint8_t* d = out.data_ptr(); #pragma omp parallel for schedule(static) for (int64_t m = 0; m < M; ++m) { const int8_t* sr = s + m * K; uint8_t* dr = d + m * K4; for (int64_t k4 = 0; k4 < K4; ++k4) { uint8_t b = 0; for (int j = 0; j < 4; ++j) { int64_t k = k4 * 4 + j; if (k >= K) break; int8_t v = sr[k]; uint8_t code = (v == 1) ? 1u : (v == -1 ? 2u : 0u); b |= (code << (6 - j * 2)); } dr[k4] = b; } } return out; } torch::Tensor unpack_ternary_cpu(torch::Tensor packed, int64_t K) { TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8, "expected uint8 [M,K4]"); auto p = packed.contiguous(); int64_t M = p.size(0), K4 = p.size(1); auto out = torch::empty({M, K}, torch::kFloat32); const uint8_t* pp = p.data_ptr(); float* dp = out.data_ptr(); #pragma omp parallel for schedule(static) for (int64_t m = 0; m < M; ++m) { const uint8_t* pr = pp + m * K4; float* dr = dp + m * K; for (int64_t k4 = 0; k4 < K4; ++k4) { uint8_t b = pr[k4]; int64_t base = k4 * 4; if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3]; if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3]; if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3]; if (base + 3 < K) dr[base + 3] = LUT[b & 3]; } } return out; } // Fused "unpack and scale" -> bf16/fp32 dense weight. Saves a pass over memory // and a temporary FP32 tensor when running under bf16 autocast. torch::Tensor dequantize_cpu(torch::Tensor packed, torch::Tensor alpha, int64_t K) { auto p = packed.contiguous(); auto a = alpha.contiguous().to(torch::kFloat32); int64_t M = p.size(0), K4 = p.size(1); auto out = torch::empty({M, K}, torch::kFloat32); const uint8_t* pp = p.data_ptr(); const float* ap = a.data_ptr(); float* dp = out.data_ptr(); #pragma omp parallel for schedule(static) for (int64_t m = 0; m < M; ++m) { const uint8_t* pr = pp + m * K4; float* dr = dp + m * K; float sc = ap[m]; for (int64_t k4 = 0; k4 < K4; ++k4) { uint8_t b = pr[k4]; int64_t base = k4 * 4; if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3] * sc; if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3] * sc; if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3] * sc; if (base + 3 < K) dr[base + 3] = LUT[b & 3] * sc; } } return out; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("pack_ternary", &pack_ternary_cpu, "Pack int8 ternary -> 2-bit uint8"); m.def("unpack_ternary", &unpack_ternary_cpu, "Unpack 2-bit uint8 -> fp32 {-1,0,1}"); m.def("dequantize", &dequantize_cpu, "Unpack and scale by per-row alpha"); } """ def _try_load_native() -> Optional[object]: """Compile/load the optional native helper. Idempotent and thread-safe.""" global _NATIVE_EXT, _NATIVE_TRIED if _NATIVE_TRIED: return _NATIVE_EXT with _NATIVE_LOCK: if _NATIVE_TRIED: return _NATIVE_EXT _NATIVE_TRIED = True try: from torch.utils.cpp_extension import load_inline build_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build" ) os.makedirs(build_dir, exist_ok=True) _NATIVE_EXT = load_inline( name="chimera_ternary", cpp_sources=_CPP_SOURCE, extra_cflags=["-O3", "-fopenmp", "-ffast-math", "-funroll-loops"], extra_ldflags=["-lgomp"], build_directory=build_dir, verbose=False, ) except Exception as exc: # pragma: no cover - best-effort. os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200]) _NATIVE_EXT = None return _NATIVE_EXT def enable_native_kernel(force: bool = False) -> bool: """Eagerly try to compile the native kernel. Returns ``True`` if the kernel is loaded and available. """ global _NATIVE_TRIED if force: _NATIVE_TRIED = False return _try_load_native() is not None def native_kernel_available() -> bool: return _NATIVE_EXT is not None # Allow opt-in from the environment without code changes. if os.environ.get("CHIMERA_NATIVE", "0") == "1": enable_native_kernel() # --------------------------------------------------------------------------- # Pure PyTorch ternary primitives (always available). # --------------------------------------------------------------------------- # Lookup tables compiled once. Casting to a registered buffer is overkill – # they live on CPU and broadcast naturally. _TERNARY_LUT_F32 = torch.tensor([0.0, 1.0, -1.0, 0.0], dtype=torch.float32) _TERNARY_LUT_I8 = torch.tensor([0, 1, -1, 0], dtype=torch.int8) _SHIFTS = torch.tensor([6, 4, 2, 0], dtype=torch.uint8) def pack_ternary(q: torch.Tensor) -> torch.Tensor: """Pack a ternary {-1,0,1} tensor into a 2-bit uint8 tensor. Vectorised pure-PyTorch implementation — no Python loops over rows. Trailing positions that don't divide by four are zero-padded. """ q = q.detach() if q.dim() == 1: q = q.unsqueeze(0) flat = q.reshape(-1, q.shape[-1]).to(torch.int8) M, K = flat.shape K4 = (K + 3) // 4 pad = K4 * 4 - K if pad: flat = F.pad(flat, (0, pad)) # codes: 0 / 1 / 2 (uint8) codes = torch.where(flat == 1, torch.full_like(flat, 1), torch.where(flat == -1, torch.full_like(flat, 2), torch.zeros_like(flat))).to(torch.uint8) codes = codes.view(M, K4, 4) packed = ((codes[..., 0] << 6) | (codes[..., 1] << 4) | (codes[..., 2] << 2) | codes[..., 3]).contiguous() return packed.reshape(*q.shape[:-1], K4) def unpack_ternary(packed: torch.Tensor, k: int, alpha: Optional[torch.Tensor] = None, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Vectorised inverse of :func:`pack_ternary`. Returns ``out`` with last dim ``k``; optionally pre-multiplied by ``alpha`` (per-row scale, broadcastable on the leading axes). """ packed = packed.to(torch.uint8) if packed.dim() == 1: packed = packed.unsqueeze(0) flat = packed.reshape(-1, packed.shape[-1]) M, K4 = flat.shape # Gather all 4 sub-positions in one vectorised op. shifts = _SHIFTS.to(packed.device) codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long) # [M, K4, 4] lut = _TERNARY_LUT_F32.to(device=packed.device, dtype=dtype) out = lut[codes].reshape(M, K4 * 4)[:, :k] if alpha is not None: out = out * alpha.reshape(M, 1).to(device=out.device, dtype=out.dtype) return out.reshape(*packed.shape[:-1], k) def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: """Per-output-channel scale (``\alpha = mean|w|`` clamped).""" return weight.detach().abs().mean(dim=-1, keepdim=False).clamp_min(eps).to(torch.float32) def ternarize_weight(weight: torch.Tensor, group_size: int = 128 ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantise FP32 weights to ternary using BitNet's abs-mean rule. ``group_size`` is kept for API compatibility but every row is its own group in this slim implementation. Returns ``(w_ternary, alpha)``. """ alpha = _absmean_alpha(weight) w_q = torch.round(torch.clamp(weight / alpha.unsqueeze(-1), -1.0, 1.0)).to(torch.int8) return w_q, alpha _quantize_weights_ternary = ternarize_weight # legacy alias used elsewhere def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor: """In-place N:M 2:4 pruning. Vectorised — no Python row loops.""" with torch.no_grad(): last = weight.shape[-1] pad = (-last) % 4 target = F.pad(weight, (0, pad)) if pad else weight view = target.view(*target.shape[:-1], -1, 4) # Keep the two largest in absolute value, zero the rest. idx = view.abs().argsort(dim=-1)[..., :2] view.scatter_(-1, idx, 0.0) if pad: weight.copy_(target[..., :last]) return weight # --------------------------------------------------------------------------- # Straight-Through Estimator for ternary quantization. # --------------------------------------------------------------------------- class _RoundTernarySTE(torch.autograd.Function): @staticmethod def forward(ctx, w: torch.Tensor) -> torch.Tensor: # type: ignore[override] return torch.round(torch.clamp(w, -1.0, 1.0)) @staticmethod def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] # Standard STE: gradient flows through, clipped to [-1, 1] so the # latent FP32 weights cannot drift unboundedly. return grad_output.clamp(-1.0, 1.0) def ste_ternary(w: torch.Tensor) -> torch.Tensor: return _RoundTernarySTE.apply(w) # --------------------------------------------------------------------------- # BitLinear — single class, single fast path. # --------------------------------------------------------------------------- class BitLinear(nn.Module): """Linear layer with ternary {-1, 0, 1} weights and per-row absmean scale. *Training (grad-enabled)*: STE ternarisation on the latent weight, dense fp32/bf16 matmul. Backward flows to the latent weight via STE. *Inference / no-grad*: weights are quantised once and cached as packed 2-bit uint8 + fp32 alpha. Each forward unpacks (vectorised PyTorch or optional C++ kernel) into a reusable buffer and calls a single matmul. """ __constants__ = ["in_features", "out_features", "use_2_4"] def __init__(self, in_features: int, out_features: int, bias: bool = False, group_size: int = 128, nm_2_4: bool = False): super().__init__() self.in_features = int(in_features) self.out_features = int(out_features) self.group_size = int(group_size) self.use_2_4 = bool(nm_2_4) self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features)) if bias: self.bias = nn.Parameter(torch.zeros(self.out_features)) else: self.register_parameter("bias", None) # Caches. ``_cache_version`` is bumped whenever the latent weight # changes; the forward pass compares it against ``_packed_version`` # to know when to repack. self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False) self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False) # Optional dense fp32 cache of the dequantised ternary weight. This # is what every inference forward actually needs, so caching it # eliminates the per-call unpack and saves ~30-50% of CPU time on # small models. It is only built lazily on first inference call. self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False) self._packed_version = -1 self._dense_version = -1 self._cache_version = 0 self.reset_parameters() # -- init ------------------------------------------------------------------ def reset_parameters(self) -> None: nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) if self.bias is not None: nn.init.zeros_(self.bias) self._cache_version += 1 # -- helpers --------------------------------------------------------------- def invalidate_packed(self) -> None: """Mark the packed cache stale. Called after weight mutations.""" self._cache_version += 1 # Free the dense fp32 cache too; next forward will rebuild it. if self._dense_w.numel() > 0: self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device) self._dense_version = -1 def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]: """Quantise the FP32 latent weight to ternary (no-grad, no copy).""" with torch.no_grad(): w = self.weight alpha = _absmean_alpha(w) w_q = torch.round(torch.clamp(w / alpha.unsqueeze(-1), -1.0, 1.0)) if self.use_2_4: apply_2_4_sparsity_(w_q) return w_q.to(torch.int8), alpha def _ensure_packed(self) -> None: if self._packed_version == self._cache_version and self._packed.numel() > 0: return with torch.no_grad(): w_q, alpha = self._quantize_latent() ext = _NATIVE_EXT if ext is not None: packed = ext.pack_ternary(w_q) else: packed = pack_ternary(w_q) # Replace storage in-place to avoid breaking nn.Module buffer tracking. self._packed = packed.contiguous() self._alpha = alpha.contiguous() self._packed_version = self._cache_version @torch.no_grad() def prepare_for_inference(self) -> None: """Materialise the packed cache so the next forward is allocation-free.""" self.invalidate_packed() self._ensure_packed() @torch.no_grad() def ternary_nonzero_mask(self) -> torch.Tensor: """Boolean mask of currently non-zero ternary positions (cached).""" self._ensure_packed() # Reuse the dequantised float view through unpack — cheaper than a fresh # dense ternary tensor on small models, and shared for both branches. ext = _NATIVE_EXT if ext is not None: w = ext.unpack_ternary(self._packed, self.in_features) else: w = unpack_ternary(self._packed, self.in_features) return w.ne(0) # -- forward --------------------------------------------------------------- def _forward_train(self, x: torch.Tensor) -> torch.Tensor: """STE forward: differentiable, fp32/bf16 dense matmul.""" w = self.weight alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5) w_q = ste_ternary(w / alpha) * alpha if self.use_2_4: # 2:4 sparsity is non-differentiable but only zeros gradients on # already-pruned positions; safe to apply during STE forward. with torch.no_grad(): mask = (apply_2_4_sparsity_(w_q.detach().clone()) != 0).to(w_q.dtype) w_q = w_q * mask return F.linear(x, w_q.to(x.dtype), self.bias) def _ensure_dense(self) -> torch.Tensor: """Materialise (and cache) the fp32 dense ternary weight.""" self._ensure_packed() if self._dense_version == self._cache_version and self._dense_w.numel() > 0: return self._dense_w ext = _NATIVE_EXT if ext is not None: w = ext.dequantize(self._packed, self._alpha, self.in_features) else: w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1) # Replace the buffer in place so nn.Module book-keeping stays valid. self._dense_w = w.contiguous() self._dense_version = self._cache_version return self._dense_w def _forward_packed(self, x: torch.Tensor) -> torch.Tensor: """No-grad fast path that uses the cached dequantised weights.""" w = self._ensure_dense() # Match dtype (bf16 autocast support) without re-allocating the cache. if x.dtype != w.dtype: w_used = w.to(x.dtype) else: w_used = w return F.linear(x, w_used, self.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.training and torch.is_grad_enabled(): return self._forward_train(x) return self._forward_packed(x) # -- introspection --------------------------------------------------------- def extra_repr(self) -> str: return (f"in_features={self.in_features}, out_features={self.out_features}, " f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, " f"native={native_kernel_available()}") # --------------------------------------------------------------------------- # RMSNorm. # --------------------------------------------------------------------------- class RMSNorm(nn.Module): """Numerically-stable Root Mean Square LayerNorm (no bias, no centering).""" __constants__ = ["dim", "eps"] def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.dim = int(dim) self.eps = float(eps) self.weight = nn.Parameter(torch.ones(self.dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: # The normalisation is computed in fp32 for stability under bf16 # autocast, then cast back to the input dtype. dtype = x.dtype if dtype != torch.float32: x32 = x.float() rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True).add(self.eps)) return (x32 * rms).to(dtype) * self.weight rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True).add(self.eps)) return x * rms * self.weight __all__ = [ "BitLinear", "RMSNorm", "ste_ternary", "pack_ternary", "unpack_ternary", "ternarize_weight", "_quantize_weights_ternary", "apply_2_4_sparsity_", "enable_native_kernel", "native_kernel_available", ]