chomera / chimera /quantization.py
Lgr54HFi's picture
fix: NaN at step 150 — add gradient clamping to STE detach trick + lower max_grad_norm to 0.5\n\nThe pure detach() STE passes gradients through unbounded, causing\ngradient explosion around step 140-150 when loss is still high.\n\nFix: clamp the gradient contribution within the detach trick:\n w_q = clamp(w_scaled, -1, 1) + (round(clamped) - clamped).detach()\nThis ensures gradients are zero outside [-1, 1] (weights already at\nquantization boundary get no gradient push) while keeping the STE\nidentity pass-through inside the valid range.\n\nAlso reduces max_grad_norm from 1.0 to 0.5 for additional stability.\n\nRef: 4-bit CPU training paper (2603.13931) uses tanh soft clipping\nfor the same reason."
ec200d2 verified
"""
Chimera 5.2 — 1.58-bit Ternary Compute (CPU-First, Slim)
========================================================
Single, clean implementation of BitNet-1.58 ternary linear layers.
Design goals:
* Zero overhead at import time (no JIT, no kernel discovery).
* One fast pure-PyTorch path that vectorises everything; an optional
C++/OpenMP path that is built *lazily*, only on request (CHIMERA_NATIVE=1
or enable_native_kernel()), and is then used to pack/decode the cached weights.
* Cache the packed 2-bit weights between forward calls and rebuild them only
when the cache is invalidated after the latent FP32 weights are mutated
(training step or MeZO).
* No data-dependent Python loops, no per-row mask construction at init.
* torch.compile compatible: STE uses detach() trick (zero graph breaks).
Storage:
weight: FP32 latent of shape [M, K] (kept for STE backward / MeZO updates)
_packed: uint8 [M, ceil(K/4)] (2 bits per ternary value)
_alpha: fp32 [M] (per-row absolute mean scale)
Encoding (matches the C++ kernel):
-1 → 0b10
0 → 0b00
+1 → 0b01
(0b11 is unused; the decoders map it to 0.)
"""
from __future__ import annotations
import math
import os
import threading
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Lazy C++ kernel.
# ---------------------------------------------------------------------------
# Shared build state for _try_load_native(): the lock serialises the one-off
# build, _NATIVE_EXT holds the loaded extension module (or None) and
# _NATIVE_TRIED records that a build attempt has been made.
_NATIVE_LOCK = threading.Lock()
_NATIVE_EXT: Optional[object] = None
_NATIVE_TRIED = False
# CPU-only C++/OpenMP kernels. The 2-bit code table (0b00 -> 0, 0b01 -> +1,
# 0b10 -> -1, 0b11 -> 0) must match pack_ternary()/unpack_ternary(). Every
# entry point validates device, dtype and shape before touching raw pointers:
# a CUDA tensor would otherwise be dereferenced as host memory, a short
# ``alpha`` would be read out of bounds, and K > 4*K4 would leave columns of
# the torch::empty() result uninitialised.
_CPP_SOURCE = r"""
#include <torch/extension.h>
#include <cstdint>
#include <cmath>
#ifdef _OPENMP
#include <omp.h>
#endif

// Decode table indexed by a 2-bit code; 0b11 is unused and decodes to 0.
static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f};

torch::Tensor pack_ternary_cpu(torch::Tensor w) {
    TORCH_CHECK(w.dim() == 2 && w.dtype() == torch::kInt8, "expected int8 [M,K]");
    TORCH_CHECK(w.is_cpu(), "pack_ternary: expected a CPU tensor");
    auto w_c = w.contiguous();
    int64_t M = w_c.size(0), K = w_c.size(1);
    int64_t K4 = (K + 3) / 4;
    auto out = torch::zeros({M, K4}, torch::kUInt8);
    const int8_t* s = w_c.data_ptr<int8_t>();
    uint8_t* d = out.data_ptr<uint8_t>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const int8_t* sr = s + m * K;
        uint8_t* dr = d + m * K4;
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = 0;
            for (int j = 0; j < 4; ++j) {
                int64_t k = k4 * 4 + j;
                if (k >= K) break;
                int8_t v = sr[k];
                uint8_t code = (v == 1) ? 1u : (v == -1 ? 2u : 0u);
                b |= (code << (6 - j * 2));
            }
            dr[k4] = b;
        }
    }
    return out;
}

torch::Tensor unpack_ternary_cpu(torch::Tensor packed, int64_t K) {
    TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8, "expected uint8 [M,K4]");
    TORCH_CHECK(packed.is_cpu(), "unpack_ternary: expected a CPU tensor");
    auto p = packed.contiguous();
    int64_t M = p.size(0), K4 = p.size(1);
    TORCH_CHECK(K >= 0 && K <= K4 * 4, "unpack_ternary: K must lie in [0, 4*K4]");
    auto out = torch::empty({M, K}, torch::kFloat32);
    const uint8_t* pp = p.data_ptr<uint8_t>();
    float* dp = out.data_ptr<float>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const uint8_t* pr = pp + m * K4;
        float* dr = dp + m * K;
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = pr[k4];
            int64_t base = k4 * 4;
            if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3];
            if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3];
            if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3];
            if (base + 3 < K) dr[base + 3] = LUT[b & 3];
        }
    }
    return out;
}

torch::Tensor dequantize_cpu(torch::Tensor packed, torch::Tensor alpha, int64_t K) {
    TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8, "expected uint8 [M,K4]");
    TORCH_CHECK(packed.is_cpu() && alpha.is_cpu(), "dequantize: expected CPU tensors");
    auto p = packed.contiguous();
    auto a = alpha.contiguous().to(torch::kFloat32);
    int64_t M = p.size(0), K4 = p.size(1);
    TORCH_CHECK(a.numel() == M, "dequantize: alpha must have one entry per row");
    TORCH_CHECK(K >= 0 && K <= K4 * 4, "dequantize: K must lie in [0, 4*K4]");
    auto out = torch::empty({M, K}, torch::kFloat32);
    const uint8_t* pp = p.data_ptr<uint8_t>();
    const float* ap = a.data_ptr<float>();
    float* dp = out.data_ptr<float>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const uint8_t* pr = pp + m * K4;
        float* dr = dp + m * K;
        float sc = ap[m];
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = pr[k4];
            int64_t base = k4 * 4;
            if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3] * sc;
            if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3] * sc;
            if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3] * sc;
            if (base + 3 < K) dr[base + 3] = LUT[b & 3] * sc;
        }
    }
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("pack_ternary", &pack_ternary_cpu, "Pack int8 ternary -> 2-bit uint8");
    m.def("unpack_ternary", &unpack_ternary_cpu, "Unpack 2-bit uint8 -> fp32 {-1,0,1}");
    m.def("dequantize", &dequantize_cpu, "Unpack and scale by per-row alpha");
}
"""
def _try_load_native() -> Optional[object]:
    """Compile/load the optional C++ ternary kernel, memoising the outcome.

    Returns the loaded extension module, or ``None`` when the build failed.
    On failure the exception text (truncated to 200 characters) is stored in
    ``os.environ["CHIMERA_NATIVE_DISABLED"]`` unless that variable is already
    set. Later calls return the memoised result without rebuilding until
    enable_native_kernel(force=True) clears the memo.
    """
    global _NATIVE_EXT, _NATIVE_TRIED
    # Lock-free fast path once an attempt has finished.
    if _NATIVE_TRIED:
        return _NATIVE_EXT
    with _NATIVE_LOCK:
        if _NATIVE_TRIED:
            return _NATIVE_EXT
        try:
            from torch.utils.cpp_extension import load_inline
            build_dir = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build"
            )
            os.makedirs(build_dir, exist_ok=True)
            _NATIVE_EXT = load_inline(
                name="chimera_ternary",
                cpp_sources=_CPP_SOURCE,
                extra_cflags=["-O3", "-fopenmp", "-ffast-math", "-funroll-loops"],
                extra_ldflags=["-lgomp"],
                build_directory=build_dir,
                verbose=False,
            )
        except Exception as exc:
            # Best effort: fall back to the pure-PyTorch path but leave a trace.
            os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200])
            _NATIVE_EXT = None
        finally:
            # Publish the "tried" flag only once _NATIVE_EXT holds the final
            # outcome. Setting it before the (slow) build let concurrent callers
            # take the lock-free fast path above and receive None mid-build.
            _NATIVE_TRIED = True
    return _NATIVE_EXT
def enable_native_kernel(force: bool = False) -> bool:
    """Load the optional C++ kernel if possible and report whether it is loaded.

    Args:
        force: forget the outcome of any previous attempt (success or failure)
            and attempt the build again.

    Returns:
        True when the extension module is loaded, False otherwise.
    """
    global _NATIVE_TRIED
    if force:
        _NATIVE_TRIED = False  # re-arm the one-shot guard in _try_load_native()
    loaded = _try_load_native()
    return loaded is not None
def native_kernel_available() -> bool:
    """Return True if the C++ kernel has already been loaded.

    Unlike enable_native_kernel(), this never triggers a build.
    """
    ext = _NATIVE_EXT
    return ext is not None
# Opt-in eager build of the C++ kernel at import time: CHIMERA_NATIVE=1.
if os.environ.get("CHIMERA_NATIVE") == "1":
    enable_native_kernel()
# ---------------------------------------------------------------------------
# Pure PyTorch ternary primitives.
# ---------------------------------------------------------------------------
# Decode tables indexed by a 2-bit code: 0b00 -> 0, 0b01 -> +1, 0b10 -> -1,
# and the unused 0b11 -> 0 (same table as the C++ LUT).
_TERNARY_LUT_F32 = torch.tensor((0.0, 1.0, -1.0, 0.0), dtype=torch.float32)
_TERNARY_LUT_I8 = _TERNARY_LUT_F32.to(torch.int8)
# Right-shift amounts bringing each of a byte's four codes (MSB pair first)
# down to the low two bits.
_SHIFTS = torch.tensor((6, 4, 2, 0), dtype=torch.uint8)
def pack_ternary(q: torch.Tensor) -> torch.Tensor:
    """Pack ternary values {-1, 0, +1} into 2-bit codes, four values per byte.

    Values are converted to int8 first; +1 -> 0b01, -1 -> 0b10 and every other
    value -> 0b00 (same table as the C++ kernel). Within a byte the first value
    occupies the two most significant bits. Rows whose length is not a
    multiple of 4 are zero-padded before packing.

    Args:
        q: tensor of shape [..., K]; a 1-D input is treated as a single row.
            The input is detached, so no gradient is tracked.

    Returns:
        uint8 tensor of shape [..., ceil(K / 4)] ([1, ceil(K / 4)] for 1-D input).
    """
    src = q.detach()
    if src.dim() == 1:
        src = src.unsqueeze(0)
    lead_shape = src.shape[:-1]
    rows = src.reshape(-1, src.shape[-1]).to(torch.int8)
    n_rows, n_cols = rows.shape
    n_bytes = (n_cols + 3) // 4
    tail = 4 * n_bytes - n_cols
    if tail:
        rows = F.pad(rows, (0, tail))
    is_pos = (rows == 1).to(torch.uint8)
    is_neg = (rows == -1).to(torch.uint8)
    code = is_pos | (is_neg << 1)
    quad = code.view(n_rows, n_bytes, 4)
    byte = quad[..., 0] << 6
    byte = byte | (quad[..., 1] << 4)
    byte = byte | (quad[..., 2] << 2)
    byte = byte | quad[..., 3]
    return byte.contiguous().reshape(*lead_shape, n_bytes)
def unpack_ternary(packed: torch.Tensor, k: int,
                   alpha: Optional[torch.Tensor] = None,
                   dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Decode 2-bit packed ternary codes back into {-1, 0, +1} values.

    Inverse of pack_ternary(): each byte holds four consecutive values of a
    row, most significant bit pair first, with 0b00 -> 0, 0b01 -> +1,
    0b10 -> -1 (the unused 0b11 also decodes to 0).

    Args:
        packed: packed codes of shape [..., K4] (converted to uint8); a 1-D
            input is treated as a single row.
        k: number of logical values per row to keep; must satisfy
            0 <= k <= 4 * K4. Trailing padding codes beyond k are dropped.
        alpha: optional per-row scale with one entry per row (the product of
            the leading dimensions), multiplied into the result.
        dtype: dtype of the decoded values.

    Returns:
        Tensor of shape [..., k] on packed's device ([1, k] for 1-D input).

    Raises:
        ValueError: if k is negative or larger than the 4 * K4 values a row
            can hold.
    """
    packed = packed.to(torch.uint8)
    if packed.dim() == 1:
        packed = packed.unsqueeze(0)
    flat = packed.reshape(-1, packed.shape[-1])
    M, K4 = flat.shape
    # Validate explicitly: a negative k would be treated as a slice from the
    # end (k == -1 even reshapes "successfully" to the wrong width), and
    # k > 4 * K4 would only fail later with an opaque reshape error.
    if k < 0 or k > 4 * K4:
        raise ValueError(
            f"unpack_ternary: k={k} is out of range for {K4} packed bytes per row "
            f"(expected 0 <= k <= {4 * K4})"
        )
    shifts = _SHIFTS.to(packed.device)
    # Split every byte into its four 2-bit codes (MSB pair first).
    codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long)
    lut = _TERNARY_LUT_F32.to(device=packed.device, dtype=dtype)
    out = lut[codes].reshape(M, K4 * 4)[:, :k]
    if alpha is not None:
        out = out * alpha.reshape(M, 1).to(device=out.device, dtype=out.dtype)
    return out.reshape(*packed.shape[:-1], k)
def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Per-row BitNet scale: mean |w| over the last dimension, floored at *eps*.

    The input is detached (no gradient flows through the scale); the floor is
    applied in the input dtype and the result is then cast to float32, with
    the last dimension reduced away.
    """
    magnitude = torch.abs(weight.detach())
    row_mean = torch.mean(magnitude, dim=-1)
    return torch.clamp(row_mean, min=eps).float()


def ternarize_weight(weight: torch.Tensor, group_size: int = 128
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantise *weight* to int8 ternary codes plus its per-row absmean scale.

    Returns ``(q, alpha)`` where ``q = round(clamp(weight / alpha, -1, 1))``
    as int8 and ``alpha`` comes from _absmean_alpha(). ``group_size`` is
    accepted for API compatibility and is not used (scales are per row).
    """
    scale = _absmean_alpha(weight)
    normalised = weight / scale.unsqueeze(-1)
    codes = normalised.clamp(-1.0, 1.0).round().to(torch.int8)
    return codes, scale


# Backward-compatible alias for the older private name.
_quantize_weights_ternary = ternarize_weight
def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
    """Impose 2:4 structured sparsity on *weight* in place and return it.

    The last dimension is split into consecutive groups of four, and in every
    group the two entries with the smallest magnitude are set to zero. A
    trailing partial group is zero-padded for the selection; the padding
    counts as zero-magnitude entries, so a partial group of one or two real
    values keeps them all.

    Ties are broken deterministically by a stable sort: among equal
    magnitudes the lower index is zeroed first. This matters for ternary
    inputs, where every non-zero entry of a group has the same magnitude and
    an unstable sort leaves the surviving entries implementation-defined.
    """
    with torch.no_grad():
        last = weight.shape[-1]
        pad = (-last) % 4
        # F.pad returns a copy; without padding we operate on weight directly.
        target = F.pad(weight, (0, pad)) if pad else weight
        view = target.view(*target.shape[:-1], -1, 4)
        idx = torch.sort(view.abs(), dim=-1, stable=True).indices[..., :2]
        view.scatter_(-1, idx, 0.0)
        if pad:
            weight.copy_(target[..., :last])
    return weight
# ---------------------------------------------------------------------------
# Straight-Through Estimator for ternary quantization.
# ---------------------------------------------------------------------------
#
# CLAMP-AWARE STE using the detach() trick:
#
# clamped = clamp(w, -1, 1)
# w_q = clamped + (round(clamped) - clamped).detach()
#
# Forward: evaluates to round(clamp(w, -1, 1)) — same as before.
# Backward: ∂/∂w [clamp(w, -1, 1)] = 1 if |w| <= 1 else 0.
# → Gradients are ZERO for weights outside [-1, 1] (at quantization boundary).
# → Gradients pass through unchanged inside [-1, 1] (STE identity).
#
# This prevents gradient explosion that caused NaN at step ~150 with the
# pure identity STE (w + (quant - w).detach()). The clamp derivative acts
# as a natural gradient gate: weights that have drifted beyond the ternary
# range get no gradient push, preventing runaway accumulation.
#
# Ref: 4-bit CPU training (arxiv:2603.13931) uses tanh soft clipping for
# the same stabilization purpose.
# ---------------------------------------------------------------------------
class _RoundTernarySTE(torch.autograd.Function):
    """LEGACY — kept for backward compat. Use ste_ternary() instead.

    Forward rounds ``clamp(w, -1, 1)``. Backward passes the incoming gradient
    straight through with its *values* clipped to [-1, 1]; unlike
    ste_ternary() it is not gated on the position of ``w``.
    """

    @staticmethod
    def forward(ctx, w: torch.Tensor) -> torch.Tensor:
        bounded = w.clamp(min=-1.0, max=1.0)
        return bounded.round()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        return torch.clamp(grad_output, min=-1.0, max=1.0)
def ste_ternary(w: torch.Tensor) -> torch.Tensor:
    """Ternarise *w* with a clamp-gated straight-through estimator.

    Forward value: ``round(clamp(w, -1, 1))``.
    Gradient: that of ``clamp`` — identity for |w| <= 1, zero outside.
    Built from plain ops plus ``detach()`` so torch.compile sees no graph break.
    """
    gated = w.clamp(min=-1.0, max=1.0)
    # The rounding residual is detached: it shifts the value, not the gradient.
    residual = (gated.round() - gated).detach()
    return gated + residual
# ---------------------------------------------------------------------------
# BitLinear
# ---------------------------------------------------------------------------
class BitLinear(nn.Module):
    """Linear layer with ternary {-1, 0, 1} weights and per-row absmean scale.

    *Training* (``self.training`` and autograd enabled): the FP32 latent
    ``weight`` is ternarised on the fly with a clamp-aware straight-through
    estimator, so gradients reach the latent weights.
    *Otherwise* (eval mode or ``torch.no_grad()``): a dense weight decoded
    from cached packed 2-bit uint8 codes is used.

    Cache validity is tracked with a stamp
    ``(_cache_version, weight._version, weight.data_ptr())``:

    * ``_cache_version`` is bumped by reset_parameters()/invalidate_packed();
    * the tensor version counter is bumped by in-place updates made through
      the parameter itself (e.g. optimizer steps, ``copy_``), so stale packed
      weights are not served after training steps;
    * the storage pointer changes when the parameter's storage is swapped.

    Updates that bypass the version counter (e.g. writes through
    ``weight.data``) still need an explicit invalidate_packed() call.

    Args:
        in_features: size of each input sample (K).
        out_features: size of each output sample (M).
        bias: add a zero-initialised learnable bias if True.
        group_size: stored as ``self.group_size``; not read by this class.
        nm_2_4: apply 2:4 structured sparsity to the ternary weights.
    """

    __constants__ = ["in_features", "out_features", "use_2_4"]

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 group_size: int = 128, nm_2_4: bool = False):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.group_size = int(group_size)
        self.use_2_4 = bool(nm_2_4)
        self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(self.out_features))
        else:
            self.register_parameter("bias", None)
        # Non-persistent caches (excluded from the state dict).
        self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
        self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
        self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
        # Stamps the caches were built for; -1 means "never built".
        self._packed_version = -1
        self._dense_version = -1
        # Explicit invalidation counter, part of every cache stamp.
        self._cache_version = 0
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise weight (Kaiming uniform) and bias (zeros); drop caches."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        self._cache_version += 1

    def invalidate_packed(self) -> None:
        """Force a repack on next use and release the dense cache."""
        self._cache_version += 1
        if self._dense_w.numel() > 0:
            self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
        self._dense_version = -1

    def _weight_stamp(self) -> Tuple[int, int, int]:
        """Return the key identifying the current latent weights for caching."""
        w = self.weight
        # Inference-mode tensors carry no version counter (reading it raises),
        # so for them only the explicit counter and storage pointer are used.
        version = 0 if w.is_inference() else w._version
        return (self._cache_version, version, w.data_ptr())

    def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Ternarise the latent weight: int8 codes [M, K] and float32 alpha [M]."""
        with torch.no_grad():
            w = self.weight
            alpha = _absmean_alpha(w)
            w_q = torch.round(torch.clamp(w / alpha.unsqueeze(-1), -1.0, 1.0))
            if self.use_2_4:
                apply_2_4_sparsity_(w_q)
        return w_q.to(torch.int8), alpha

    def _ensure_packed(self) -> None:
        """(Re)build the packed codes and alpha if the weight stamp changed."""
        stamp = self._weight_stamp()
        if self._packed_version == stamp and self._packed.numel() > 0:
            return
        with torch.no_grad():
            w_q, alpha = self._quantize_latent()
            ext = _NATIVE_EXT
            # The C++ kernel reads raw host pointers: only feed it CPU tensors.
            if ext is not None and w_q.device.type == "cpu":
                packed = ext.pack_ternary(w_q)
            else:
                packed = pack_ternary(w_q)
        self._packed = packed.contiguous()
        self._alpha = alpha.contiguous()
        self._packed_version = stamp

    @torch.no_grad()
    def prepare_for_inference(self) -> None:
        """Drop all caches and repack from the current latent weights."""
        self.invalidate_packed()
        self._ensure_packed()

    @torch.no_grad()
    def ternary_nonzero_mask(self) -> torch.Tensor:
        """Boolean [out_features, in_features] mask of non-zero ternary weights."""
        self._ensure_packed()
        ext = _NATIVE_EXT
        if ext is not None and self._packed.device.type == "cpu":
            w = ext.unpack_ternary(self._packed, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features)
        return w.ne(0)

    def _bias_like(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        """Return the bias cast to x's dtype (F.linear needs matching dtypes)."""
        b = self.bias
        if b is not None and b.dtype != x.dtype:
            b = b.to(x.dtype)
        return b

    def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
        """STE forward with clamp-aware gradient gating.

        The clamp on w_scaled ensures:
        - Forward: round(clamp(w/alpha, -1, 1)) * alpha — correct ternary
        - Backward: gradient is ZERO for w_scaled outside [-1, 1],
          preventing gradient explosion from weights at the boundary.
        """
        w = self.weight
        alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
        w_scaled = w / alpha
        # Clamp FIRST, then detach the rounding residual.
        # Gradient of clamp: 1 inside [-1, 1], 0 outside → natural gradient gate.
        clamped = torch.clamp(w_scaled, -1.0, 1.0)
        w_q = clamped + (torch.round(clamped) - clamped).detach()
        w_q = w_q * alpha
        if self.use_2_4:
            with torch.no_grad():
                mask = (apply_2_4_sparsity_(w_q.detach().clone()) != 0).to(w_q.dtype)
            w_q = w_q * mask
        return F.linear(x, w_q.to(x.dtype), self._bias_like(x))

    def _ensure_dense(self) -> torch.Tensor:
        """Return the dense float weight [M, K] decoded from the packed cache."""
        self._ensure_packed()
        stamp = self._packed_version  # stamp of the (now current) packed cache
        if self._dense_version == stamp and self._dense_w.numel() > 0:
            return self._dense_w
        ext = _NATIVE_EXT
        if ext is not None and self._packed.device.type == "cpu":
            w = ext.dequantize(self._packed, self._alpha, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1)
        self._dense_w = w.contiguous()
        self._dense_version = stamp
        return self._dense_w

    def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
        """Forward with the cached dequantised weight (no gradient to weight)."""
        w = self._ensure_dense()
        if w.dtype != x.dtype:
            w = w.to(x.dtype)
        return F.linear(x, w, self._bias_like(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and torch.is_grad_enabled():
            return self._forward_train(x)
        return self._forward_packed(x)

    def extra_repr(self) -> str:
        return (f"in_features={self.in_features}, out_features={self.out_features}, "
                f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, "
                f"native={native_kernel_available()}")
# ---------------------------------------------------------------------------
# RMSNorm
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature gain.

    ``y = x * rsqrt(mean(x**2, dim=-1) + eps) * weight``. Inputs that are not
    float32 are normalised in float32 and cast back to their own dtype
    *before* the gain is applied, so the output dtype follows normal type
    promotion between that dtype and ``weight``'s dtype.
    """

    __constants__ = ["dim", "eps"]

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = int(dim)
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        upcast = in_dtype != torch.float32
        work = x.float() if upcast else x
        inv_rms = torch.rsqrt(work.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
        normed = work * inv_rms
        if upcast:
            normed = normed.to(in_dtype)
        return normed * self.weight
# Public API of this module. "_quantize_weights_ternary" is listed despite
# its underscore because it is a backward-compatible alias of ternarize_weight.
__all__ = [
    "BitLinear",
    "RMSNorm",
    "ste_ternary",
    "pack_ternary",
    "unpack_ternary",
    "ternarize_weight",
    "_quantize_weights_ternary",
    "apply_2_4_sparsity_",
    "enable_native_kernel",
    "native_kernel_available",
]