| """ |
| Chimera 5.2 — 1.58-bit Ternary Compute (CPU-First, Slim) |
| ======================================================== |
| Single, clean implementation of BitNet-1.58 ternary linear layers. |
| |
| Design goals: |
| * Zero overhead at import time (no JIT, no kernel discovery). |
| * One fast pure-PyTorch path that vectorises everything; an optional |
| C++/OpenMP path that is loaded *lazily* and only used when it actually |
| beats PyTorch (small batches on inference). |
| * Cache the packed 2-bit weights between forward calls and only repack |
| when the latent FP32 weights are mutated (training step or MeZO). |
| * No data-dependent Python loops, no per-row mask construction at init. |
| |
| Storage: |
| weight: FP32 latent of shape [M, K] (kept for STE backward / MeZO updates) |
| _packed: uint8 [M, ceil(K/4)] (2 bits per ternary value) |
| _alpha: fp32 [M] (per-row absolute mean scale) |
| |
| Encoding (matches the C++ kernel): |
| -1 → 0b10 |
| 0 → 0b00 |
| +1 → 0b01 |
| """ |

from __future__ import annotations

import math
import os
import threading
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Optional native C++/OpenMP helper (built lazily, never at import time)
# ---------------------------------------------------------------------------
_NATIVE_LOCK = threading.Lock()
_NATIVE_EXT: Optional[object] = None
_NATIVE_TRIED = False

_CPP_SOURCE = r"""
#include <torch/extension.h>
#include <cstdint>
#include <cmath>
#ifdef _OPENMP
#include <omp.h>
#endif

// Encoding: -1 -> 0b10, 0 -> 0b00, +1 -> 0b01 (two bits per value; the
// first value of each group of four sits in the high bits of the byte).
static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f};

torch::Tensor pack_ternary_cpu(torch::Tensor w) {
    TORCH_CHECK(w.dim() == 2 && w.dtype() == torch::kInt8, "expected int8 [M,K]");
    auto w_c = w.contiguous();
    int64_t M = w_c.size(0), K = w_c.size(1);
    int64_t K4 = (K + 3) / 4;
    auto out = torch::zeros({M, K4}, torch::kUInt8);
    const int8_t* s = w_c.data_ptr<int8_t>();
    uint8_t* d = out.data_ptr<uint8_t>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const int8_t* sr = s + m * K;
        uint8_t* dr = d + m * K4;
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = 0;
            for (int j = 0; j < 4; ++j) {
                int64_t k = k4 * 4 + j;
                if (k >= K) break;
                int8_t v = sr[k];
                uint8_t code = (v == 1) ? 1u : (v == -1 ? 2u : 0u);
                b |= (code << (6 - j * 2));
            }
            dr[k4] = b;
        }
    }
    return out;
}

torch::Tensor unpack_ternary_cpu(torch::Tensor packed, int64_t K) {
    TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8, "expected uint8 [M,K4]");
    auto p = packed.contiguous();
    int64_t M = p.size(0), K4 = p.size(1);
    auto out = torch::empty({M, K}, torch::kFloat32);
    const uint8_t* pp = p.data_ptr<uint8_t>();
    float* dp = out.data_ptr<float>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const uint8_t* pr = pp + m * K4;
        float* dr = dp + m * K;
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = pr[k4];
            int64_t base = k4 * 4;
            if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3];
            if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3];
            if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3];
            if (base + 3 < K) dr[base + 3] = LUT[b & 3];
        }
    }
    return out;
}

// Fused "unpack and scale" -> fp32 dense weight. Saves a pass over memory
// and a temporary tensor when running under bf16 autocast.
torch::Tensor dequantize_cpu(torch::Tensor packed, torch::Tensor alpha, int64_t K) {
    TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8, "expected uint8 [M,K4]");
    auto p = packed.contiguous();
    auto a = alpha.contiguous().to(torch::kFloat32);
    int64_t M = p.size(0), K4 = p.size(1);
    auto out = torch::empty({M, K}, torch::kFloat32);
    const uint8_t* pp = p.data_ptr<uint8_t>();
    const float* ap = a.data_ptr<float>();
    float* dp = out.data_ptr<float>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const uint8_t* pr = pp + m * K4;
        float* dr = dp + m * K;
        float sc = ap[m];
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = pr[k4];
            int64_t base = k4 * 4;
            if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3] * sc;
            if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3] * sc;
            if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3] * sc;
            if (base + 3 < K) dr[base + 3] = LUT[b & 3] * sc;
        }
    }
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("pack_ternary", &pack_ternary_cpu, "Pack int8 ternary -> 2-bit uint8");
    m.def("unpack_ternary", &unpack_ternary_cpu, "Unpack 2-bit uint8 -> fp32 {-1,0,1}");
    m.def("dequantize", &dequantize_cpu, "Unpack and scale by per-row alpha");
}
"""


def _try_load_native() -> Optional[object]:
    """Compile/load the optional native helper. Idempotent and thread-safe."""
    global _NATIVE_EXT, _NATIVE_TRIED
    if _NATIVE_TRIED:
        return _NATIVE_EXT
    with _NATIVE_LOCK:
        if _NATIVE_TRIED:
            return _NATIVE_EXT
        _NATIVE_TRIED = True
        try:
            from torch.utils.cpp_extension import load_inline

            build_dir = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build"
            )
            os.makedirs(build_dir, exist_ok=True)
            _NATIVE_EXT = load_inline(
                name="chimera_ternary",
                cpp_sources=_CPP_SOURCE,
                extra_cflags=["-O3", "-fopenmp", "-ffast-math", "-funroll-loops"],
                extra_ldflags=["-lgomp"],
                build_directory=build_dir,
                verbose=False,
            )
        except Exception as exc:
            # Record why the native path is unavailable without raising.
            os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200])
            _NATIVE_EXT = None
    return _NATIVE_EXT


def enable_native_kernel(force: bool = False) -> bool:
    """Eagerly try to compile the native kernel.

    Returns ``True`` if the kernel is loaded and available.
    """
    global _NATIVE_TRIED
    if force:
        with _NATIVE_LOCK:
            _NATIVE_TRIED = False
    return _try_load_native() is not None


def native_kernel_available() -> bool:
    """Report whether the native extension is loaded (never triggers a build)."""
    return _NATIVE_EXT is not None


# Opt-in: build the native helper eagerly when CHIMERA_NATIVE=1 is set, so
# the one-off compile cost is paid at import instead of the first forward.
if os.environ.get("CHIMERA_NATIVE", "0") == "1":
    enable_native_kernel()
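
# Callers can also opt in programmatically (a sketch; the import path
# `chimera.ternary` is illustrative, and the first call may spend several
# seconds compiling):
#
#     >>> import chimera.ternary as ct        # hypothetical package layout
#     >>> ct.enable_native_kernel()           # True if the build succeeded
#     >>> ct.native_kernel_available()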


# ---------------------------------------------------------------------------
# Pure-PyTorch packing / unpacking
# ---------------------------------------------------------------------------

# Shared lookup tables for the vectorised paths; kept on CPU and moved to
# the packed tensor's device on demand.
_TERNARY_LUT_F32 = torch.tensor([0.0, 1.0, -1.0, 0.0], dtype=torch.float32)
_TERNARY_LUT_I8 = torch.tensor([0, 1, -1, 0], dtype=torch.int8)
_SHIFTS = torch.tensor([6, 4, 2, 0], dtype=torch.uint8)


def pack_ternary(q: torch.Tensor) -> torch.Tensor:
    """Pack a ternary {-1, 0, 1} tensor into 2-bit codes in uint8.

    Vectorised pure-PyTorch implementation with no Python loops over rows.
    When the last dim is not a multiple of four, the final byte is
    zero-padded. A 1-D input yields a 1-D output.
    """
    q = q.detach()
    squeeze = q.dim() == 1
    if squeeze:
        q = q.unsqueeze(0)
    flat = q.reshape(-1, q.shape[-1]).to(torch.int8)
    M, K = flat.shape
    K4 = (K + 3) // 4
    pad = K4 * 4 - K
    if pad:
        flat = F.pad(flat, (0, pad))
    # Map {-1, 0, +1} -> codes {0b10, 0b00, 0b01}.
    codes = torch.where(flat == 1, torch.full_like(flat, 1),
                        torch.where(flat == -1, torch.full_like(flat, 2),
                                    torch.zeros_like(flat))).to(torch.uint8)
    codes = codes.view(M, K4, 4)
    packed = ((codes[..., 0] << 6) | (codes[..., 1] << 4) |
              (codes[..., 2] << 2) | codes[..., 3]).contiguous()
    out = packed.reshape(*q.shape[:-1], K4)
    return out.squeeze(0) if squeeze else out


def unpack_ternary(packed: torch.Tensor, k: int,
                   alpha: Optional[torch.Tensor] = None,
                   dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Vectorised inverse of :func:`pack_ternary`.

    Returns a tensor whose last dim is ``k``, optionally pre-multiplied by
    ``alpha`` (one scale per row). A 1-D input yields a 1-D output.
    """
    packed = packed.to(torch.uint8)
    squeeze = packed.dim() == 1
    if squeeze:
        packed = packed.unsqueeze(0)
    flat = packed.reshape(-1, packed.shape[-1])
    M, K4 = flat.shape
    # Pull the four 2-bit codes out of every byte, then look up the values.
    shifts = _SHIFTS.to(packed.device)
    codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long)
    lut = _TERNARY_LUT_F32.to(device=packed.device, dtype=dtype)
    out = lut[codes].reshape(M, K4 * 4)[:, :k]
    if alpha is not None:
        out = out * alpha.reshape(M, 1).to(device=out.device, dtype=out.dtype)
    out = out.reshape(*packed.shape[:-1], k)
    return out.squeeze(0) if squeeze else out
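
# Round-trip sketch for the two functions above:
#
#     >>> q = torch.tensor([[1, -1, 0, 1]], dtype=torch.int8)
#     >>> pack_ternary(q)                     # 0b01_10_00_01 == 97
#     tensor([[97]], dtype=torch.uint8)
#     >>> unpack_ternary(pack_ternary(q), 4)
#     tensor([[ 1., -1.,  0.,  1.]])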


def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Per-output-channel scale ``alpha = mean(|w|)``, clamped below by ``eps``."""
    return weight.detach().abs().mean(dim=-1).clamp_min(eps).to(torch.float32)


def ternarize_weight(weight: torch.Tensor, group_size: int = 128
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantise FP32 weights to ternary using BitNet's abs-mean rule.

    ``group_size`` is kept for API compatibility, but every row is its own
    group in this slim implementation. Returns ``(w_ternary, alpha)``.
    """
    alpha = _absmean_alpha(weight)
    w_q = torch.round(torch.clamp(weight / alpha.unsqueeze(-1), -1.0, 1.0)).to(torch.int8)
    return w_q, alpha
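
# Worked abs-mean example: alpha is the row's mean |w|, and weights round
# to ternary after scaling by 1/alpha:
#
#     >>> w = torch.tensor([[0.4, -0.05, 0.3, -0.25]])
#     >>> q, a = ternarize_weight(w)
#     >>> a                                   # (0.4 + 0.05 + 0.3 + 0.25) / 4
#     tensor([0.2500])
#     >>> q                                   # round(clamp(w / 0.25, -1, 1))
#     tensor([[ 1,  0,  1, -1]], dtype=torch.int8)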


# Backwards-compatible alias for older call sites.
_quantize_weights_ternary = ternarize_weight


def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
    """In-place 2:4 structured (N:M) pruning. Vectorised, no Python row loops."""
    with torch.no_grad():
        last = weight.shape[-1]
        pad = (-last) % 4
        target = F.pad(weight, (0, pad)) if pad else weight
        view = target.view(*target.shape[:-1], -1, 4)
        # Zero the two smallest-magnitude entries in every group of four.
        idx = view.abs().argsort(dim=-1)[..., :2]
        view.scatter_(-1, idx, 0.0)
        if pad:
            weight.copy_(target[..., :last])
    return weight
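
# 2:4 sketch: within each aligned group of four, the two smallest-magnitude
# entries are zeroed, keeping at most two survivors per group:
#
#     >>> w = torch.tensor([[0.9, -0.1, 0.2, -0.8]])
#     >>> apply_2_4_sparsity_(w)
#     tensor([[ 0.9000,  0.0000,  0.0000, -0.8000]])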


# ---------------------------------------------------------------------------
# Straight-through estimator (STE) for the ternary rounding
# ---------------------------------------------------------------------------

class _RoundTernarySTE(torch.autograd.Function):
    """Round-to-ternary with straight-through (clipped) gradients."""

    @staticmethod
    def forward(ctx, w: torch.Tensor) -> torch.Tensor:
        return torch.round(torch.clamp(w, -1.0, 1.0))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Straight-through: treat round(clamp(.)) as the identity, clipping
        # the incoming gradient to the clamp range to keep updates bounded.
        return grad_output.clamp(-1.0, 1.0)


def ste_ternary(w: torch.Tensor) -> torch.Tensor:
    return _RoundTernarySTE.apply(w)
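
# Gradient sketch: backward treats round(clamp(.)) as the identity, so the
# latent weight receives the (clipped) upstream gradient even where the
# forward output saturated:
#
#     >>> w = torch.tensor([0.3, -2.0], requires_grad=True)
#     >>> ste_ternary(w).sum().backward()
#     >>> w.grad
#     tensor([1., 1.])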


# ---------------------------------------------------------------------------
# Modules
# ---------------------------------------------------------------------------

class BitLinear(nn.Module):
    """Linear layer with ternary {-1, 0, 1} weights and a per-row abs-mean scale.

    *Training (grad-enabled)*: STE ternarisation of the latent weight and a
    dense fp32/bf16 matmul. Backward flows to the latent weight via the STE.

    *Inference / no-grad*: weights are quantised once and cached as packed
    2-bit uint8 plus fp32 alpha. Each forward unpacks them (vectorised
    PyTorch or the optional C++ kernel) into a reusable buffer and runs a
    single matmul.
    """

    __constants__ = ["in_features", "out_features", "use_2_4"]

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 group_size: int = 128, nm_2_4: bool = False):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.group_size = int(group_size)
        self.use_2_4 = bool(nm_2_4)

        self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(self.out_features))
        else:
            self.register_parameter("bias", None)

        # Packed ternary cache. Non-persistent: rebuilt on demand rather
        # than serialised with the state dict.
        self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
        self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)

        # Dense fp32 dequantised weight, cached separately so the no-grad
        # forward is a single matmul with no per-call unpacking.
        self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
        self._packed_version = -1
        self._dense_version = -1
        self._cache_version = 0

        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        self._cache_version += 1

    def invalidate_packed(self) -> None:
        """Mark the packed cache stale. Call after any weight mutation."""
        self._cache_version += 1
        # Free the dense cache eagerly; it is the larger of the two.
        if self._dense_w.numel() > 0:
            self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
        self._dense_version = -1

    def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantise the FP32 latent weight to ternary (no-grad)."""
        with torch.no_grad():
            w = self.weight
            alpha = _absmean_alpha(w)
            w_q = torch.round(torch.clamp(w / alpha.unsqueeze(-1), -1.0, 1.0))
            if self.use_2_4:
                apply_2_4_sparsity_(w_q)
            return w_q.to(torch.int8), alpha

    def _ensure_packed(self) -> None:
        if self._packed_version == self._cache_version and self._packed.numel() > 0:
            return
        with torch.no_grad():
            w_q, alpha = self._quantize_latent()
            # Prefer the native packer when it has been built.
            ext = _NATIVE_EXT
            if ext is not None:
                packed = ext.pack_ternary(w_q)
            else:
                packed = pack_ternary(w_q)
            self._packed = packed.contiguous()
            self._alpha = alpha.contiguous()
            self._packed_version = self._cache_version

    @torch.no_grad()
    def prepare_for_inference(self) -> None:
        """Materialise the packed cache so the next forward is allocation-free."""
        self.invalidate_packed()
        self._ensure_packed()

    @torch.no_grad()
    def ternary_nonzero_mask(self) -> torch.Tensor:
        """Boolean mask of non-zero ternary positions (from the packed cache)."""
        self._ensure_packed()
        # Unpack without the alpha scale; only the zero pattern matters here.
        ext = _NATIVE_EXT
        if ext is not None:
            w = ext.unpack_ternary(self._packed, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features)
        return w.ne(0)

    def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
        """STE forward: differentiable, dense fp32/bf16 matmul."""
        w = self.weight
        alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
        w_q = ste_ternary(w / alpha) * alpha
        if self.use_2_4:
            # Recompute the 2:4 mask from the current ternary weights; the
            # detached mask keeps the multiply differentiable w.r.t. `w_q`.
            with torch.no_grad():
                mask = (apply_2_4_sparsity_(w_q.detach().clone()) != 0).to(w_q.dtype)
            w_q = w_q * mask
        return F.linear(x, w_q.to(x.dtype), self.bias)

    def _ensure_dense(self) -> torch.Tensor:
        """Materialise (and cache) the fp32 dense ternary weight."""
        self._ensure_packed()
        if self._dense_version == self._cache_version and self._dense_w.numel() > 0:
            return self._dense_w
        ext = _NATIVE_EXT
        if ext is not None:
            # Fused unpack-and-scale in one pass over the packed bytes.
            w = ext.dequantize(self._packed, self._alpha, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1)
        self._dense_w = w.contiguous()
        self._dense_version = self._cache_version
        return self._dense_w

    def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
        """No-grad fast path over the cached dequantised weights."""
        w = self._ensure_dense()
        # Match the activation dtype (e.g. under bf16 autocast) without
        # mutating the fp32 cache.
        w_used = w.to(x.dtype) if x.dtype != w.dtype else w
        return F.linear(x, w_used, self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and torch.is_grad_enabled():
            return self._forward_train(x)
        return self._forward_packed(x)

    def extra_repr(self) -> str:
        return (f"in_features={self.in_features}, out_features={self.out_features}, "
                f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, "
                f"native={native_kernel_available()}")


class RMSNorm(nn.Module):
    """Numerically stable root-mean-square LayerNorm (no bias, no centering)."""

    __constants__ = ["dim", "eps"]

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = int(dim)
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in fp32 for stability, casting back before the scale.
        dtype = x.dtype
        if dtype != torch.float32:
            x32 = x.float()
            rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
            return (x32 * rms).to(dtype) * self.weight
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
        return x * rms * self.weight
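
# Equivalence sketch against a manual fp32 computation (weight is ones at
# init, so the affine scale is a no-op here):
#
#     >>> norm = RMSNorm(4)
#     >>> x = torch.randn(2, 4)
#     >>> ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
#     >>> bool(torch.allclose(norm(x), ref))
#     True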


__all__ = [
    "BitLinear",
    "RMSNorm",
    "ste_ternary",
    "pack_ternary",
    "unpack_ternary",
    "ternarize_weight",
    "_quantize_weights_ternary",
    "apply_2_4_sparsity_",
    "enable_native_kernel",
    "native_kernel_available",
]
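

if __name__ == "__main__":
    # Tiny self-check (illustrative smoke test, not a substitute for tests):
    # round-trip the packing, then run one forward through each path.
    torch.manual_seed(0)
    q = torch.randint(-1, 2, (3, 10), dtype=torch.int8)
    assert torch.equal(unpack_ternary(pack_ternary(q), 10).to(torch.int8), q)

    layer = BitLinear(10, 4, bias=True)
    x = torch.randn(2, 10)
    layer.train()
    y_train = layer(x)                 # STE path (grad-enabled)
    layer.eval()
    with torch.no_grad():
        y_eval = layer(x)              # packed / cached-dense path
    print("train:", tuple(y_train.shape), "eval:", tuple(y_eval.shape))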