fix: NaN at step 150 -> add gradient clamping to STE detach trick + lower max_grad_norm to 0.5

The pure detach() STE passes gradients through unbounded, causing
gradient explosion around step 140-150 when loss is still high.

Fix: clamp the gradient contribution inside the detach trick:

    clamped = clamp(w_scaled, -1, 1)
    w_q = clamped + (round(clamped) - clamped).detach()

This ensures gradients are zero outside [-1, 1] (weights already at the
quantization boundary get no gradient push) while keeping the STE
identity pass-through inside the valid range.

Also reduces max_grad_norm from 1.0 to 0.5 for additional stability.

Ref: 4-bit CPU training paper (2603.13931) uses tanh soft clipping
for the same reason.
"""
Chimera 5.2 - 1.58-bit Ternary Compute (CPU-First, Slim)
========================================================

Single, clean implementation of BitNet-1.58 ternary linear layers.

Design goals:
  * Zero overhead at import time (no JIT, no kernel discovery).
  * One fast pure-PyTorch path that vectorises everything; an optional
    C++/OpenMP path that is loaded *lazily* and only used when it actually
    beats PyTorch (small batches at inference).
  * Cache the packed 2-bit weights between forward calls and only repack
    when the latent FP32 weights are mutated (training step or MeZO).
  * No data-dependent Python loops, no per-row mask construction at init.
  * torch.compile compatible: the STE uses the detach() trick (zero graph breaks).

Storage:
  weight:  FP32 latent of shape [M, K] (kept for STE backward / MeZO updates)
  _packed: uint8 [M, ceil(K/4)] (2 bits per ternary value)
  _alpha:  fp32 [M] (per-row absolute mean scale)

Encoding (matches the C++ kernel):
  -1 -> 0b10
   0 -> 0b00
  +1 -> 0b01
"""
from __future__ import annotations

import math
import os
import threading
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------------------------------------
# Lazy C++ kernel.
# ---------------------------------------------------------------------------
_NATIVE_LOCK = threading.Lock()
_NATIVE_EXT: Optional[object] = None
_NATIVE_TRIED = False
| _CPP_SOURCE = r""" | |
| #include <torch/extension.h> | |
| #include <cstdint> | |
| #include <cmath> | |
| #ifdef _OPENMP | |
| #include <omp.h> | |
| #endif | |
| static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f}; | |
| torch::Tensor pack_ternary_cpu(torch::Tensor w) { | |
| TORCH_CHECK(w.dim() == 2 && w.dtype() == torch::kInt8, "expected int8 [M,K]"); | |
| auto w_c = w.contiguous(); | |
| int64_t M = w_c.size(0), K = w_c.size(1); | |
| int64_t K4 = (K + 3) / 4; | |
| auto out = torch::zeros({M, K4}, torch::kUInt8); | |
| const int8_t* s = w_c.data_ptr<int8_t>(); | |
| uint8_t* d = out.data_ptr<uint8_t>(); | |
| #pragma omp parallel for schedule(static) | |
| for (int64_t m = 0; m < M; ++m) { | |
| const int8_t* sr = s + m * K; | |
| uint8_t* dr = d + m * K4; | |
| for (int64_t k4 = 0; k4 < K4; ++k4) { | |
| uint8_t b = 0; | |
| for (int j = 0; j < 4; ++j) { | |
| int64_t k = k4 * 4 + j; | |
| if (k >= K) break; | |
| int8_t v = sr[k]; | |
| uint8_t code = (v == 1) ? 1u : (v == -1 ? 2u : 0u); | |
| b |= (code << (6 - j * 2)); | |
| } | |
| dr[k4] = b; | |
| } | |
| } | |
| return out; | |
| } | |
| torch::Tensor unpack_ternary_cpu(torch::Tensor packed, int64_t K) { | |
| TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8, "expected uint8 [M,K4]"); | |
| auto p = packed.contiguous(); | |
| int64_t M = p.size(0), K4 = p.size(1); | |
| auto out = torch::empty({M, K}, torch::kFloat32); | |
| const uint8_t* pp = p.data_ptr<uint8_t>(); | |
| float* dp = out.data_ptr<float>(); | |
| #pragma omp parallel for schedule(static) | |
| for (int64_t m = 0; m < M; ++m) { | |
| const uint8_t* pr = pp + m * K4; | |
| float* dr = dp + m * K; | |
| for (int64_t k4 = 0; k4 < K4; ++k4) { | |
| uint8_t b = pr[k4]; | |
| int64_t base = k4 * 4; | |
| if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3]; | |
| if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3]; | |
| if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3]; | |
| if (base + 3 < K) dr[base + 3] = LUT[b & 3]; | |
| } | |
| } | |
| return out; | |
| } | |
| torch::Tensor dequantize_cpu(torch::Tensor packed, torch::Tensor alpha, int64_t K) { | |
| auto p = packed.contiguous(); | |
| auto a = alpha.contiguous().to(torch::kFloat32); | |
| int64_t M = p.size(0), K4 = p.size(1); | |
| auto out = torch::empty({M, K}, torch::kFloat32); | |
| const uint8_t* pp = p.data_ptr<uint8_t>(); | |
| const float* ap = a.data_ptr<float>(); | |
| float* dp = out.data_ptr<float>(); | |
| #pragma omp parallel for schedule(static) | |
| for (int64_t m = 0; m < M; ++m) { | |
| const uint8_t* pr = pp + m * K4; | |
| float* dr = dp + m * K; | |
| float sc = ap[m]; | |
| for (int64_t k4 = 0; k4 < K4; ++k4) { | |
| uint8_t b = pr[k4]; | |
| int64_t base = k4 * 4; | |
| if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3] * sc; | |
| if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3] * sc; | |
| if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3] * sc; | |
| if (base + 3 < K) dr[base + 3] = LUT[b & 3] * sc; | |
| } | |
| } | |
| return out; | |
| } | |
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
| m.def("pack_ternary", &pack_ternary_cpu, "Pack int8 ternary -> 2-bit uint8"); | |
| m.def("unpack_ternary", &unpack_ternary_cpu, "Unpack 2-bit uint8 -> fp32 {-1,0,1}"); | |
| m.def("dequantize", &dequantize_cpu, "Unpack and scale by per-row alpha"); | |
| } | |
| """ | |
def _try_load_native() -> Optional[object]:
    global _NATIVE_EXT, _NATIVE_TRIED
    if _NATIVE_TRIED:
        return _NATIVE_EXT
    with _NATIVE_LOCK:
        if _NATIVE_TRIED:
            return _NATIVE_EXT
        _NATIVE_TRIED = True
        try:
            from torch.utils.cpp_extension import load_inline
            build_dir = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build"
            )
            os.makedirs(build_dir, exist_ok=True)
            _NATIVE_EXT = load_inline(
                name="chimera_ternary",
                cpp_sources=_CPP_SOURCE,
                extra_cflags=["-O3", "-fopenmp", "-ffast-math", "-funroll-loops"],
                extra_ldflags=["-lgomp"],
                build_directory=build_dir,
                verbose=False,
            )
        except Exception as exc:
            os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200])
            _NATIVE_EXT = None
        return _NATIVE_EXT


def enable_native_kernel(force: bool = False) -> bool:
    global _NATIVE_TRIED
    if force:
        _NATIVE_TRIED = False
    return _try_load_native() is not None


def native_kernel_available() -> bool:
    return _NATIVE_EXT is not None


if os.environ.get("CHIMERA_NATIVE", "0") == "1":
    enable_native_kernel()
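# Usage sketch (the import path below is illustrative, not part of this file):
# the native kernel is opt-in, either via CHIMERA_NATIVE=1 in the environment
# before import, or explicitly at runtime:
#
#   >>> from chimera_ternary import enable_native_kernel, native_kernel_available
#   >>> enable_native_kernel()                  # lazily JIT-builds the C++/OpenMP extension
#   >>> native_kernel_available()
#
# On build failure the reason is recorded in os.environ["CHIMERA_NATIVE_DISABLED"]
# and every code path falls back to the pure-PyTorch implementation below.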
# ---------------------------------------------------------------------------
# Pure PyTorch ternary primitives.
# ---------------------------------------------------------------------------
_TERNARY_LUT_F32 = torch.tensor([0.0, 1.0, -1.0, 0.0], dtype=torch.float32)
_TERNARY_LUT_I8 = torch.tensor([0, 1, -1, 0], dtype=torch.int8)
_SHIFTS = torch.tensor([6, 4, 2, 0], dtype=torch.uint8)


def pack_ternary(q: torch.Tensor) -> torch.Tensor:
    q = q.detach()
    if q.dim() == 1:
        q = q.unsqueeze(0)
    flat = q.reshape(-1, q.shape[-1]).to(torch.int8)
    M, K = flat.shape
    K4 = (K + 3) // 4
    pad = K4 * 4 - K
    if pad:
        flat = F.pad(flat, (0, pad))
    codes = torch.where(flat == 1, torch.full_like(flat, 1),
                        torch.where(flat == -1, torch.full_like(flat, 2),
                                    torch.zeros_like(flat))).to(torch.uint8)
    codes = codes.view(M, K4, 4)
    packed = ((codes[..., 0] << 6) | (codes[..., 1] << 4) |
              (codes[..., 2] << 2) | codes[..., 3]).contiguous()
    return packed.reshape(*q.shape[:-1], K4)


def unpack_ternary(packed: torch.Tensor, k: int,
                   alpha: Optional[torch.Tensor] = None,
                   dtype: torch.dtype = torch.float32) -> torch.Tensor:
    packed = packed.to(torch.uint8)
    if packed.dim() == 1:
        packed = packed.unsqueeze(0)
    flat = packed.reshape(-1, packed.shape[-1])
    M, K4 = flat.shape
    shifts = _SHIFTS.to(packed.device)
    codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long)
    lut = _TERNARY_LUT_F32.to(device=packed.device, dtype=dtype)
    out = lut[codes].reshape(M, K4 * 4)[:, :k]
    if alpha is not None:
        out = out * alpha.reshape(M, 1).to(device=out.device, dtype=out.dtype)
    return out.reshape(*packed.shape[:-1], k)
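# Worked example (illustrative; values chosen arbitrarily). The group
# [+1, -1, 0, +1] is packed most-significant bit-pair first, giving the byte
# 0b01_10_00_01 == 0x61 == 97, and the round trip recovers the ternary values:
#
#   >>> q = torch.tensor([[1, -1, 0, 1]], dtype=torch.int8)
#   >>> p = pack_ternary(q)                     # tensor([[97]], dtype=torch.uint8)
#   >>> unpack_ternary(p, 4)                    # tensor([[ 1., -1.,  0.,  1.]])
#   >>> unpack_ternary(p, 4, alpha=torch.tensor([0.5]))   # scaled by per-row alpha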
def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    return weight.detach().abs().mean(dim=-1, keepdim=False).clamp_min(eps).to(torch.float32)


def ternarize_weight(weight: torch.Tensor, group_size: int = 128
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
    alpha = _absmean_alpha(weight)
    w_q = torch.round(torch.clamp(weight / alpha.unsqueeze(-1), -1.0, 1.0)).to(torch.int8)
    return w_q, alpha
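# Worked example (illustrative). The per-row absmean of [0.6, -0.2, 0.9, -0.3]
# is 0.5, so after divide / clamp / round every entry with |w| >= 0.25 maps to
# +/-1 and the rest collapse to 0:
#
#   >>> w = torch.tensor([[0.6, -0.2, 0.9, -0.3]])
#   >>> q, a = ternarize_weight(w)
#   >>> q        # tensor([[ 1,  0,  1, -1]], dtype=torch.int8)
#   >>> a        # tensor([0.5000])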
_quantize_weights_ternary = ternarize_weight


def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        last = weight.shape[-1]
        pad = (-last) % 4
        target = F.pad(weight, (0, pad)) if pad else weight
        view = target.view(*target.shape[:-1], -1, 4)
        idx = view.abs().argsort(dim=-1)[..., :2]
        view.scatter_(-1, idx, 0.0)
        if pad:
            weight.copy_(target[..., :last])
    return weight
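# Worked example (illustrative). In every group of 4 along the last dim, the
# two smallest-magnitude entries are zeroed in place (2:4 structured sparsity):
#
#   >>> w = torch.tensor([[0.9, -0.1, 0.4, -0.7]])
#   >>> apply_2_4_sparsity_(w)
#   tensor([[ 0.9000,  0.0000,  0.0000, -0.7000]])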
# ---------------------------------------------------------------------------
# Straight-Through Estimator for ternary quantization.
# ---------------------------------------------------------------------------
#
# CLAMP-AWARE STE using the detach() trick:
#
#     clamped = clamp(w, -1, 1)
#     w_q = clamped + (round(clamped) - clamped).detach()
#
# Forward:  evaluates to round(clamp(w, -1, 1)), same as before.
# Backward: d/dw [clamp(w, -1, 1)] = 1 if |w| <= 1 else 0.
#   -> Gradients are ZERO for weights outside [-1, 1] (at quantization boundary).
#   -> Gradients pass through unchanged inside [-1, 1] (STE identity).
#
# This prevents gradient explosion that caused NaN at step ~150 with the
# pure identity STE (w + (quant - w).detach()). The clamp derivative acts
# as a natural gradient gate: weights that have drifted beyond the ternary
# range get no gradient push, preventing runaway accumulation.
#
# Ref: 4-bit CPU training (arxiv:2603.13931) uses tanh soft clipping for
# the same stabilization purpose.
# ---------------------------------------------------------------------------
class _RoundTernarySTE(torch.autograd.Function):
    """LEGACY - kept for backward compat. Use ste_ternary() instead."""

    @staticmethod
    def forward(ctx, w: torch.Tensor) -> torch.Tensor:
        return torch.round(torch.clamp(w, -1.0, 1.0))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output.clamp(-1.0, 1.0)
def ste_ternary(w: torch.Tensor) -> torch.Tensor:
    """Straight-through estimator for ternary quantization.

    Forward:  round(clamp(w, -1, 1))
    Backward: clamp derivative (zero outside [-1, 1], identity inside)

    Uses the detach() trick for zero graph breaks under torch.compile.
    """
    clamped = torch.clamp(w, -1.0, 1.0)
    w_q = torch.round(clamped)
    return clamped + (w_q - clamped).detach()
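# Worked example (illustrative). Gradients flow only where |w| <= 1; entries
# already past the ternary boundary receive no gradient push:
#
#   >>> w = torch.tensor([0.3, -0.7, 1.5, -2.0], requires_grad=True)
#   >>> ste_ternary(w).sum().backward()
#   >>> w.grad
#   tensor([1., 1., 0., 0.])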
# ---------------------------------------------------------------------------
# BitLinear
# ---------------------------------------------------------------------------
class BitLinear(nn.Module):
    """Linear layer with ternary {-1, 0, 1} weights and per-row absmean scale.

    *Training*:  STE ternarisation with clamp-aware gradient gating.
    *Inference*: cached packed 2-bit uint8 weights.
    """

    __constants__ = ["in_features", "out_features", "use_2_4"]

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 group_size: int = 128, nm_2_4: bool = False):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.group_size = int(group_size)
        self.use_2_4 = bool(nm_2_4)
        self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(self.out_features))
        else:
            self.register_parameter("bias", None)
        self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
        self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
        self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
        self._packed_version = -1
        self._dense_version = -1
        self._cache_version = 0
        self.reset_parameters()
    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        self._cache_version += 1

    def invalidate_packed(self) -> None:
        self._cache_version += 1
        if self._dense_w.numel() > 0:
            self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
        self._dense_version = -1

    def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            w = self.weight
            alpha = _absmean_alpha(w)
            w_q = torch.round(torch.clamp(w / alpha.unsqueeze(-1), -1.0, 1.0))
            if self.use_2_4:
                apply_2_4_sparsity_(w_q)
            return w_q.to(torch.int8), alpha

    def _ensure_packed(self) -> None:
        if self._packed_version == self._cache_version and self._packed.numel() > 0:
            return
        with torch.no_grad():
            w_q, alpha = self._quantize_latent()
            ext = _NATIVE_EXT
            if ext is not None:
                packed = ext.pack_ternary(w_q)
            else:
                packed = pack_ternary(w_q)
            self._packed = packed.contiguous()
            self._alpha = alpha.contiguous()
            self._packed_version = self._cache_version

    def prepare_for_inference(self) -> None:
        self.invalidate_packed()
        self._ensure_packed()

    def ternary_nonzero_mask(self) -> torch.Tensor:
        self._ensure_packed()
        ext = _NATIVE_EXT
        if ext is not None:
            w = ext.unpack_ternary(self._packed, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features)
        return w.ne(0)
    def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
        """STE forward with clamp-aware gradient gating.

        The clamp on w_scaled ensures:
          - Forward:  round(clamp(w/alpha, -1, 1)) * alpha -> correct ternary
          - Backward: gradient is ZERO for w_scaled outside [-1, 1],
            preventing gradient explosion from weights at the boundary.
        """
        w = self.weight
        alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
        w_scaled = w / alpha
        # Clamp FIRST, then detach the rounding residual.
        # Gradient of clamp: 1 inside [-1, 1], 0 outside -> natural gradient gate.
        clamped = torch.clamp(w_scaled, -1.0, 1.0)
        w_q = clamped + (torch.round(clamped) - clamped).detach()
        w_q = w_q * alpha
        if self.use_2_4:
            with torch.no_grad():
                mask = (apply_2_4_sparsity_(w_q.detach().clone()) != 0).to(w_q.dtype)
            w_q = w_q * mask
        return F.linear(x, w_q.to(x.dtype), self.bias)
    def _ensure_dense(self) -> torch.Tensor:
        self._ensure_packed()
        if self._dense_version == self._cache_version and self._dense_w.numel() > 0:
            return self._dense_w
        ext = _NATIVE_EXT
        if ext is not None:
            w = ext.dequantize(self._packed, self._alpha, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1)
        self._dense_w = w.contiguous()
        self._dense_version = self._cache_version
        return self._dense_w

    def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
        w = self._ensure_dense()
        if x.dtype != w.dtype:
            w_used = w.to(x.dtype)
        else:
            w_used = w
        return F.linear(x, w_used, self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and torch.is_grad_enabled():
            return self._forward_train(x)
        return self._forward_packed(x)

    def extra_repr(self) -> str:
        return (f"in_features={self.in_features}, out_features={self.out_features}, "
                f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, "
                f"native={native_kernel_available()}")
# ---------------------------------------------------------------------------
# RMSNorm
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    __constants__ = ["dim", "eps"]

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = int(dim)
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(self.dim))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        if dtype != torch.float32:
            # Normalise in fp32 for stability, then cast back; scaling before the
            # cast keeps the output dtype equal to the input dtype.
            x32 = x.float()
            rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
            return (x32 * rms * self.weight.float()).to(dtype)
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
        return x * rms * self.weight
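# Sanity-check sketch: with the default all-ones weight, the RMS of each output
# row is approximately 1.0 (up to eps):
#
#   >>> norm = RMSNorm(8)
#   >>> y = norm(torch.randn(2, 8))
#   >>> y.pow(2).mean(dim=-1).sqrt()            # ~1.0 per row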
__all__ = [
    "BitLinear",
    "RMSNorm",
    "ste_ternary",
    "pack_ternary",
    "unpack_ternary",
    "ternarize_weight",
    "_quantize_weights_ternary",
    "apply_2_4_sparsity_",
    "enable_native_kernel",
    "native_kernel_available",
]