ch1mera / chimera /quantization.py
Lgr54HFi's picture
Upload folder using huggingface_hub
6e408ce verified
"""
Chimera 5.2 — 1.58-bit Ternary Compute (CPU-First, Slim)
========================================================
Single, clean implementation of BitNet-1.58 ternary linear layers.
Design goals:
* Zero overhead at import time (no JIT, no kernel discovery).
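* One fast pure-PyTorch path that vectorises everything; an optional
C++/OpenMP path that is built *lazily*, only on explicit request
(``enable_native_kernel`` or ``CHIMERA_NATIVE=1``), and once loaded replaces
the PyTorch pack/unpack helpers where it applies (it is not benchmarked
against them at runtime).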
* Cache the packed 2-bit weights (plus a dequantised dense copy) between
forward calls and only rebuild them after the cache has been invalidated
because the latent FP32 weights changed (training step or MeZO).
* No data-dependent Python loops, no per-row mask construction at init.
Storage:
weight: FP32 latent of shape [M, K] (kept for STE backward / MeZO updates)
_packed: uint8 [M, ceil(K/4)] (2 bits per ternary value)
_alpha: fp32 [M] (per-row absolute mean scale)
Encoding (matches the C++ kernel):
-1 → 0b10
0 → 0b00
+1 → 0b01
"""
from __future__ import annotations
import math
import os
import threading
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Optional native (C++/OpenMP) helper. Nothing is compiled during import:
# the extension is only built when explicitly requested through
# :func:`enable_native_kernel` or by exporting ``CHIMERA_NATIVE=1``.
# Every public API also works through the pure-PyTorch fallback.
# ---------------------------------------------------------------------------
_NATIVE_TRIED = False                  # True once a build has been attempted
_NATIVE_EXT: Optional[object] = None   # loaded extension module, or None
_NATIVE_LOCK = threading.Lock()        # serialises build attempts
# C++ source for the optional native helper, compiled on demand by
# ``_try_load_native``. All three kernels validate device, shape and bounds
# with TORCH_CHECK so that misuse raises instead of reading/writing out of
# bounds (the kernels dereference raw host pointers).
_CPP_SOURCE = r"""
#include <torch/extension.h>
#include <cstdint>
#include <cmath>
#ifdef _OPENMP
#include <omp.h>
#endif

// Encoding: -1->0b10, 0->0b00, +1->0b01 (0b11 is never produced; decodes to 0).
static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f};

// Shared validation for packed [M, K4] inputs. K must fit inside the 4*K4
// available codes, otherwise trailing columns of the torch::empty() output
// would never be written and would leak uninitialised memory.
static void check_packed(const torch::Tensor& packed, int64_t K) {
    TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8,
                "expected uint8 [M,K4]");
    TORCH_CHECK(packed.is_cpu(), "native ternary kernels only accept CPU tensors");
    TORCH_CHECK(K >= 0 && K <= packed.size(1) * 4,
                "K is out of range for the packed width");
}

torch::Tensor pack_ternary_cpu(torch::Tensor w) {
    TORCH_CHECK(w.dim() == 2 && w.dtype() == torch::kInt8, "expected int8 [M,K]");
    TORCH_CHECK(w.is_cpu(), "native ternary kernels only accept CPU tensors");
    auto w_c = w.contiguous();
    int64_t M = w_c.size(0), K = w_c.size(1);
    int64_t K4 = (K + 3) / 4;
    auto out = torch::zeros({M, K4}, torch::kUInt8);
    const int8_t* s = w_c.data_ptr<int8_t>();
    uint8_t* d = out.data_ptr<uint8_t>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const int8_t* sr = s + m * K;
        uint8_t* dr = d + m * K4;
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = 0;
            for (int j = 0; j < 4; ++j) {
                int64_t k = k4 * 4 + j;
                if (k >= K) break;
                int8_t v = sr[k];
                uint8_t code = (v == 1) ? 1u : (v == -1 ? 2u : 0u);
                b |= (code << (6 - j * 2));
            }
            dr[k4] = b;
        }
    }
    return out;
}

torch::Tensor unpack_ternary_cpu(torch::Tensor packed, int64_t K) {
    check_packed(packed, K);
    auto p = packed.contiguous();
    int64_t M = p.size(0), K4 = p.size(1);
    auto out = torch::empty({M, K}, torch::kFloat32);
    const uint8_t* pp = p.data_ptr<uint8_t>();
    float* dp = out.data_ptr<float>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const uint8_t* pr = pp + m * K4;
        float* dr = dp + m * K;
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = pr[k4];
            int64_t base = k4 * 4;
            if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3];
            if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3];
            if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3];
            if (base + 3 < K) dr[base + 3] = LUT[b & 3];
        }
    }
    return out;
}

// Fused "unpack and scale" -> fp32 dense weight. Avoids materialising the
// unscaled {-1,0,1} tensor and a second pass over memory.
torch::Tensor dequantize_cpu(torch::Tensor packed, torch::Tensor alpha, int64_t K) {
    check_packed(packed, K);
    TORCH_CHECK(alpha.is_cpu(), "native ternary kernels only accept CPU tensors");
    TORCH_CHECK(alpha.numel() == packed.size(0),
                "alpha must hold exactly one scale per packed row");
    auto p = packed.contiguous();
    auto a = alpha.contiguous().to(torch::kFloat32);
    int64_t M = p.size(0), K4 = p.size(1);
    auto out = torch::empty({M, K}, torch::kFloat32);
    const uint8_t* pp = p.data_ptr<uint8_t>();
    const float* ap = a.data_ptr<float>();
    float* dp = out.data_ptr<float>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const uint8_t* pr = pp + m * K4;
        float* dr = dp + m * K;
        float sc = ap[m];
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = pr[k4];
            int64_t base = k4 * 4;
            if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3] * sc;
            if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3] * sc;
            if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3] * sc;
            if (base + 3 < K) dr[base + 3] = LUT[b & 3] * sc;
        }
    }
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("pack_ternary", &pack_ternary_cpu, "Pack int8 ternary -> 2-bit uint8");
    m.def("unpack_ternary", &unpack_ternary_cpu, "Unpack 2-bit uint8 -> fp32 {-1,0,1}");
    m.def("dequantize", &dequantize_cpu, "Unpack and scale by per-row alpha");
}
"""
def _try_load_native() -> Optional[object]:
    """Compile/load the optional native helper; thread-safe and tried once.

    Returns the loaded extension module, or ``None`` if it could not be
    built. On failure the (truncated) error text is stored in the
    ``CHIMERA_NATIVE_DISABLED`` environment variable unless it is already set.

    ``_NATIVE_TRIED`` is flipped to ``True`` only *after* the build attempt
    has finished and ``_NATIVE_EXT`` holds its final value. A concurrent
    caller therefore either blocks on the lock while the build runs, or sees
    the finished result. It can no longer observe "tried" together with a
    still-``None`` extension mid-build.
    """
    global _NATIVE_EXT, _NATIVE_TRIED
    if _NATIVE_TRIED:
        return _NATIVE_EXT
    with _NATIVE_LOCK:
        if _NATIVE_TRIED:
            return _NATIVE_EXT
        ext: Optional[object] = None
        try:
            from torch.utils.cpp_extension import load_inline
            build_dir = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build"
            )
            os.makedirs(build_dir, exist_ok=True)
            ext = load_inline(
                name="chimera_ternary",
                cpp_sources=_CPP_SOURCE,
                extra_cflags=["-O3", "-fopenmp", "-ffast-math", "-funroll-loops"],
                extra_ldflags=["-lgomp"],
                build_directory=build_dir,
                verbose=False,
            )
        except Exception as exc:  # pragma: no cover - best-effort.
            # Deliberately best-effort: callers fall back to pure PyTorch.
            os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200])
            ext = None
        # Publish the result before the flag (see docstring).
        _NATIVE_EXT = ext
        _NATIVE_TRIED = True
    return _NATIVE_EXT
def enable_native_kernel(force: bool = False) -> bool:
    """Try to compile and load the optional native kernel right now.

    Args:
        force: forget the outcome of any earlier attempt and try again.

    Returns:
        ``True`` if the native extension is loaded afterwards.
    """
    global _NATIVE_TRIED
    if force:
        # Clearing the flag makes _try_load_native perform a fresh attempt.
        _NATIVE_TRIED = False
    loaded = _try_load_native()
    return loaded is not None
def native_kernel_available() -> bool:
    """Report whether the native extension is currently loaded.

    This never triggers a build; use :func:`enable_native_kernel` for that.
    """
    ext = _NATIVE_EXT
    return ext is not None
# Environment opt-in: ``CHIMERA_NATIVE=1`` builds the native kernel at import
# without any code changes. Any other value (or unset) leaves it disabled.
if os.environ.get("CHIMERA_NATIVE") == "1":
    enable_native_kernel()
# ---------------------------------------------------------------------------
# Pure PyTorch ternary primitives (always available).
# ---------------------------------------------------------------------------
# Decode tables and shift amounts, built once at import time on CPU.
# Registering them as module buffers would be overkill; the float table and
# the shifts are moved to the target device/dtype at their point of use.
# Table index = 2-bit code: 0b00 -> 0, 0b01 -> +1, 0b10 -> -1, 0b11 -> 0.
_TERNARY_LUT_F32 = torch.tensor((0.0, 1.0, -1.0, 0.0), dtype=torch.float32)
_TERNARY_LUT_I8 = torch.tensor((0, 1, -1, 0), dtype=torch.int8)
# Right-shifts that bring each 2-bit field of a byte (most significant
# first) down to the low bits.
_SHIFTS = torch.tensor((6, 4, 2, 0), dtype=torch.uint8)
def pack_ternary(q: torch.Tensor) -> torch.Tensor:
    """Pack a ternary {-1, 0, 1} tensor into 2-bit codes, four per uint8.

    The input is detached and cast to int8. ``+1`` is stored as ``0b01``,
    ``-1`` as ``0b10`` and every other value as ``0b00``, with the first
    element of each quad in the most significant bits. The last dimension is
    zero-padded up to a multiple of four. A 1-D input is treated as one row,
    so it yields a ``[1, ceil(K / 4)]`` result. Higher-rank inputs keep their
    leading dimensions.
    """
    src = q.detach()
    if src.dim() == 1:
        src = src.unsqueeze(0)
    lead_shape = src.shape[:-1]
    n_cols = src.shape[-1]
    rows = src.reshape(-1, n_cols).to(torch.int8)
    n_bytes = (n_cols + 3) // 4
    tail = 4 * n_bytes - n_cols
    if tail:
        rows = F.pad(rows, (0, tail))
    # Bit 0 marks +1, bit 1 marks -1; anything else stays 0.
    codes = (rows == 1).to(torch.uint8) | ((rows == -1).to(torch.uint8) << 1)
    quads = codes.view(rows.shape[0], n_bytes, 4)
    packed = (
        (quads[..., 0] << 6)
        | (quads[..., 1] << 4)
        | (quads[..., 2] << 2)
        | quads[..., 3]
    )
    return packed.contiguous().reshape(*lead_shape, n_bytes)
def unpack_ternary(packed: torch.Tensor, k: int,
                   alpha: Optional[torch.Tensor] = None,
                   dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Decode 2-bit ternary codes produced by :func:`pack_ternary`.

    Args:
        packed: uint8-compatible codes. The last dim holds bytes and the
            leading dims are kept. A 1-D input is treated as a single row.
        k: number of decoded values to keep per row (trailing padding is
            dropped).
        alpha: optional per-row scale. It is reshaped to ``[rows, 1]``, so it
            must contain exactly one value per packed row.
        dtype: dtype of the decoded values.

    Returns:
        Tensor of ``dtype`` whose leading dims match ``packed`` and whose
        last dim is ``k``.
    """
    src = packed.to(torch.uint8)
    if src.dim() == 1:
        src = src.unsqueeze(0)
    lead_shape = src.shape[:-1]
    rows = src.reshape(-1, src.shape[-1])
    n_rows, n_bytes = rows.shape
    dev = src.device
    # [n_rows, n_bytes, 4]: each byte split into its four 2-bit fields.
    codes = (rows.unsqueeze(-1) >> _SHIFTS.to(dev)) & 3
    table = _TERNARY_LUT_F32.to(device=dev, dtype=dtype)
    values = table[codes.long()].reshape(n_rows, n_bytes * 4)[:, :k]
    if alpha is not None:
        scale = alpha.reshape(n_rows, 1).to(device=values.device, dtype=values.dtype)
        values = values * scale
    return values.reshape(*lead_shape, k)
def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Per-output-channel scale ``alpha = mean(|w|)`` over the last dim.

    The input is detached, so no autograd graph is built. The mean is clamped
    to at least ``eps`` so that all-zero rows cannot cause a division by zero
    when weights are divided by ``alpha``. The result is returned as fp32 with
    the last dimension removed.
    """
    # NOTE: the previous docstring spelled "\alpha" in a non-raw string,
    # which Python decodes as a BEL control character followed by "lpha".
    return weight.detach().abs().mean(dim=-1).clamp_min(eps).to(torch.float32)
def ternarize_weight(weight: torch.Tensor, group_size: int = 128
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantise ``weight`` to ternary int8 with BitNet's abs-mean rule.

    Each row is divided by its own scale ``alpha = mean(|row|)``, clamped to
    [-1, 1] and rounded. ``group_size`` is accepted for API compatibility
    only and is ignored: every row is its own group.

    Returns:
        ``(w_ternary, alpha)``. The first is an int8 tensor shaped like
        ``weight`` with values in {-1, 0, 1}. The second is the fp32 per-row
        scale.
    """
    scale = _absmean_alpha(weight)
    normalised = (weight / scale.unsqueeze(-1)).clamp(-1.0, 1.0)
    return normalised.round().to(torch.int8), scale


# Legacy alias kept for callers that still import the old private name.
_quantize_weights_ternary = ternarize_weight
def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
    """Apply N:M 2:4 pruning to ``weight`` in place and return it.

    The last dimension is split into groups of four. In each group the two
    entries with the smallest absolute value are set to zero. If the last
    dimension is not a multiple of four, it is temporarily zero-padded and
    the result is copied back into ``weight``. Runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        n = weight.shape[-1]
        missing = (4 - n % 4) % 4
        # Without padding, ``groups`` is a view of ``weight`` itself.
        buf = weight if missing == 0 else F.pad(weight, (0, missing))
        groups = buf.view(*buf.shape[:-1], -1, 4)
        smallest_two = groups.abs().argsort(dim=-1)[..., :2]
        groups.scatter_(-1, smallest_two, 0.0)
        if missing:
            weight.copy_(buf[..., :n])
    return weight
# ---------------------------------------------------------------------------
# Straight-Through Estimator for ternary quantization.
# ---------------------------------------------------------------------------
class _RoundTernarySTE(torch.autograd.Function):
    """Round to {-1, 0, +1} in forward and pass a clamped gradient back.

    Forward clamps the input to [-1, 1] and rounds it. Backward returns the
    incoming gradient with each element clamped to [-1, 1]. This limits the
    gradient magnitude only. Unlike the "clipped STE" variant, it does not
    zero the gradient where |w| > 1.
    """

    @staticmethod
    def forward(ctx, w: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        clipped = w.clamp(min=-1.0, max=1.0)
        return clipped.round()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):  # type: ignore[override]
        return torch.clamp(grad_output, -1.0, 1.0)
def ste_ternary(w: torch.Tensor) -> torch.Tensor:
    """Ternarise ``w`` to {-1, 0, +1} with a straight-through gradient.

    Values are clamped to [-1, 1] and rounded. The backward rule is the one
    implemented by ``_RoundTernarySTE.backward``.
    """
    result = _RoundTernarySTE.apply(w)
    return result
# ---------------------------------------------------------------------------
# BitLinear — single class, single fast path.
# ---------------------------------------------------------------------------
class BitLinear(nn.Module):
    """Linear layer with ternary {-1, 0, 1} weights and a per-row absmean scale.

    ``weight`` is the latent FP32 parameter. The effective weight is
    ``alpha * round(clamp(weight / alpha, -1, 1))``, where ``alpha`` is
    ``mean(|weight|)`` per output row, clamped to at least 1e-5.

    *Training* (``self.training`` and grad mode enabled): STE ternarisation
    of the latent weight followed by a dense matmul in the input dtype.
    Backward flows to the latent weight through :func:`ste_ternary`.

    *Inference / no-grad*: the latent weight is quantised once and cached as
    packed 2-bit uint8 codes plus fp32 ``alpha``. A dense dequantised copy is
    cached on top of that and reused by every forward. No gradient reaches
    ``weight`` on this path.

    Cache invalidation: caches are rebuilt after :meth:`invalidate_packed`
    or :meth:`reset_parameters`. They are also rebuilt whenever ``weight``'s
    autograd version counter or storage pointer changes, which covers
    in-place optimizer updates, ``load_state_dict`` and ``weight.mul_(...)``.
    In-place edits made through ``weight.data`` may bypass the version
    counter, so such code (e.g. MeZO perturbations) should still call
    :meth:`invalidate_packed`.
    """

    __constants__ = ["in_features", "out_features", "use_2_4"]

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 group_size: int = 128, nm_2_4: bool = False):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        # Stored for API compatibility; quantisation is always per row.
        self.group_size = int(group_size)
        self.use_2_4 = bool(nm_2_4)
        self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(self.out_features))
        else:
            self.register_parameter("bias", None)
        # Non-persistent caches derived from ``weight``:
        #   _packed  - uint8 2-bit codes          (built by _ensure_packed)
        #   _alpha   - fp32 per-row scale          (built by _ensure_packed)
        #   _dense_w - fp32 dequantised weight     (built lazily by _ensure_dense)
        self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
        self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
        self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
        # ``_cache_version`` is bumped on every invalidation; each cache
        # remembers the version it was built from and is rebuilt on mismatch.
        self._packed_version = -1
        self._dense_version = -1
        self._cache_version = 0
        # (version counter, data pointer) of ``weight`` at the last check.
        self._seen_weight_key: Optional[Tuple[int, int]] = None
        self.reset_parameters()

    # -- init ------------------------------------------------------------------
    def reset_parameters(self) -> None:
        """Kaiming-uniform init for ``weight``, zeros for ``bias``; marks caches stale."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        self._cache_version += 1

    # -- cache helpers -----------------------------------------------------------
    def invalidate_packed(self) -> None:
        """Mark the packed and dense caches stale.

        Call this after mutating ``weight`` in a way its version counter may
        not track, such as writes through ``weight.data``.
        """
        self._cache_version += 1
        # Free the dense cache too; the next inference forward rebuilds it.
        if self._dense_w.numel() > 0:
            self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
        self._dense_version = -1

    def _weight_key(self) -> Optional[Tuple[int, int]]:
        """Fingerprint ``weight`` by (autograd version counter, data pointer).

        Returns ``None`` when the tensor exposes neither. Inference-mode
        tensors, for example, have no version counter. In that case only
        manual invalidation applies.
        """
        w = self.weight
        try:
            return (int(w._version), int(w.data_ptr()))
        except RuntimeError:
            return None

    def _sync_with_weight(self) -> None:
        """Invalidate the caches if ``weight`` changed since the last check."""
        key = self._weight_key()
        if key is not None and key != getattr(self, "_seen_weight_key", None):
            self.invalidate_packed()
            self._seen_weight_key = key

    @staticmethod
    def _native_ext_for(t: torch.Tensor):
        """Return the native extension if it is loaded and ``t`` lives on CPU.

        The C++ helpers dereference raw host pointers, so tensors on any
        other device must take the pure-PyTorch path.
        """
        ext = _NATIVE_EXT
        if ext is not None and t.device.type == "cpu":
            return ext
        return None

    def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantise the latent weight to int8 ternary codes plus fp32 ``alpha``.

        Runs without autograd. If enabled, 2:4 pruning is applied to the
        ternary values.
        """
        with torch.no_grad():
            w = self.weight
            alpha = _absmean_alpha(w)
            w_q = torch.round(torch.clamp(w / alpha.unsqueeze(-1), -1.0, 1.0))
            if self.use_2_4:
                apply_2_4_sparsity_(w_q)
            return w_q.to(torch.int8), alpha

    def _ensure_packed(self) -> None:
        """(Re)build ``_packed`` and ``_alpha`` if they are stale or missing."""
        self._sync_with_weight()
        if self._packed_version == self._cache_version and self._packed.numel() > 0:
            return
        with torch.no_grad():
            w_q, alpha = self._quantize_latent()
            ext = self._native_ext_for(w_q)
            packed = ext.pack_ternary(w_q) if ext is not None else pack_ternary(w_q)
            # Assigning via nn.Module.__setattr__ replaces the registered
            # buffers, so buffer bookkeeping (.to(), state tracking) stays valid.
            self._packed = packed.contiguous()
            self._alpha = alpha.contiguous()
        self._packed_version = self._cache_version

    @torch.no_grad()
    def prepare_for_inference(self) -> None:
        """Rebuild the packed and dense caches now.

        The next inference forward then reuses them instead of quantising and
        dequantising first.
        """
        self.invalidate_packed()
        self._ensure_dense()

    @torch.no_grad()
    def ternary_nonzero_mask(self) -> torch.Tensor:
        """Boolean ``[out_features, in_features]`` mask of non-zero ternary weights.

        Reuses the packed cache (rebuilding it if stale). The mask itself is
        recomputed on every call.
        """
        self._ensure_packed()
        ext = self._native_ext_for(self._packed)
        if ext is not None:
            w = ext.unpack_ternary(self._packed, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features)
        return w.ne(0)

    # -- forward ---------------------------------------------------------------
    def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
        """STE forward: differentiable dense matmul in ``x``'s dtype."""
        w = self.weight
        # The scale is a constant for autograd (computed from a detached view).
        alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
        w_q = ste_ternary(w / alpha) * alpha
        if self.use_2_4:
            with torch.no_grad():
                dense = w_q.detach()
                pruned = apply_2_4_sparsity_(dense.clone())
                # Block gradients only where 2:4 pruning removed a non-zero
                # value. Positions whose ternary value is already 0 keep their
                # STE gradient (as they do without 2:4), so they can still
                # grow. The forward value equals ``pruned`` either way.
                mask = ((dense == 0) | (pruned != 0)).to(w_q.dtype)
            w_q = w_q * mask
        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            bias = bias.to(x.dtype)
        return F.linear(x, w_q.to(x.dtype), bias)

    def _ensure_dense(self) -> torch.Tensor:
        """Return the cached fp32 dequantised weight, rebuilding it if stale."""
        self._ensure_packed()
        if self._dense_version == self._cache_version and self._dense_w.numel() > 0:
            return self._dense_w
        ext = self._native_ext_for(self._packed)
        if ext is not None:
            w = ext.dequantize(self._packed, self._alpha, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1)
        self._dense_w = w.contiguous()
        self._dense_version = self._cache_version
        return self._dense_w

    def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
        """Cached-weight forward used in eval mode or when grad is disabled."""
        w = self._ensure_dense()
        # Cast per call (e.g. for bf16 inputs); the fp32 cache is untouched.
        if w.dtype != x.dtype:
            w = w.to(x.dtype)
        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            bias = bias.to(x.dtype)
        return F.linear(x, w, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and torch.is_grad_enabled():
            return self._forward_train(x)
        return self._forward_packed(x)

    # -- introspection ---------------------------------------------------------
    def extra_repr(self) -> str:
        return (f"in_features={self.in_features}, out_features={self.out_features}, "
                f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, "
                f"native={native_kernel_available()}")
# ---------------------------------------------------------------------------
# RMSNorm.
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """Root Mean Square LayerNorm (no bias, no mean-centering).

    Computes ``y = x / sqrt(mean(x**2, dim=-1) + eps) * weight``.
    Normalisation runs in at least fp32: bf16/fp16 inputs are upcast and
    fp64 inputs stay fp64. The result, including the ``weight`` product, is
    returned in the input dtype.
    """

    __constants__ = ["dim", "eps"]

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = int(dim)
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        # Never normalise in less than fp32, and never downcast wider inputs.
        compute_dtype = torch.promote_types(dtype, torch.float32)
        xc = x.to(compute_dtype)
        rms = torch.rsqrt(xc.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
        # Cast ``weight`` to the input dtype; otherwise an fp32 weight would
        # silently promote a bf16/fp16 activation back to fp32.
        return (xc * rms).to(dtype) * self.weight.to(dtype)
# Public API. ``_quantize_weights_ternary`` is exported as a legacy alias.
__all__ = [
    # layers
    "BitLinear",
    "RMSNorm",
    # ternary primitives
    "ste_ternary",
    "pack_ternary",
    "unpack_ternary",
    "ternarize_weight",
    "_quantize_weights_ternary",
    "apply_2_4_sparsity_",
    # native-kernel control
    "enable_native_kernel",
    "native_kernel_available",
]