"""
Chimera 5.2 — 1.58-bit Ternary Compute (CPU-First, Slim)
========================================================
Single, clean implementation of BitNet-1.58 ternary linear layers.
Design goals:
* Zero overhead at import time (no JIT, no kernel discovery).
* One fast pure-PyTorch path that vectorises everything, plus an optional
C++/OpenMP path for packing/unpacking that is built and loaded *lazily*,
only when requested (``CHIMERA_NATIVE=1`` or ``enable_native_kernel()``).
* Cache the packed 2-bit weights between forward calls and only repack
when the latent FP32 weights are mutated (training step or MeZO).
* No data-dependent Python loops, no per-row mask construction at init.
* torch.compile compatible: STE uses detach() trick (zero graph breaks).
Storage:
weight: FP32 latent of shape [M, K] (kept for STE backward / MeZO updates)
_packed: uint8 [M, ceil(K/4)] (2 bits per ternary value)
_alpha: fp32 [M] (per-row absolute mean scale)
Encoding (matches the C++ kernel):
-1 → 0b10
0 → 0b00
+1 → 0b01
"""
from __future__ import annotations
import math
import os
import threading
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Lazy C++ kernel.
# ---------------------------------------------------------------------------
# Lazily-initialised native extension state, guarded by _NATIVE_LOCK in
# _try_load_native(). _NATIVE_EXT stays None until a build succeeds;
# _NATIVE_TRIED records that a build attempt has already finished.
_NATIVE_LOCK = threading.Lock()
_NATIVE_EXT: Optional[object] = None
_NATIVE_TRIED = False

# C++ source compiled by torch.utils.cpp_extension.load_inline.
# The 2-bit encoding must match the pure-PyTorch pack_ternary/unpack_ternary:
# -1 -> 0b10, 0 -> 0b00, +1 -> 0b01, and the first value of a group of four
# sits in the two most significant bits.
# Every entry point validates dtype/shape/device up front. The kernels read
# raw host pointers, so a CUDA tensor or inconsistent sizes would otherwise
# lead to invalid memory access or uninitialised output.
_CPP_SOURCE = r"""
#include <torch/extension.h>
#include <cstdint>
#include <cmath>
#ifdef _OPENMP
#include <omp.h>
#endif

// 2-bit code -> ternary value (code 0b11 is never produced by the packer).
static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f};

torch::Tensor pack_ternary_cpu(torch::Tensor w) {
    TORCH_CHECK(w.dim() == 2 && w.dtype() == torch::kInt8, "expected int8 [M,K]");
    TORCH_CHECK(w.device().is_cpu(), "pack_ternary: expected a CPU tensor");
    auto w_c = w.contiguous();
    int64_t M = w_c.size(0), K = w_c.size(1);
    int64_t K4 = (K + 3) / 4;
    auto out = torch::zeros({M, K4}, torch::kUInt8);
    const int8_t* s = w_c.data_ptr<int8_t>();
    uint8_t* d = out.data_ptr<uint8_t>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const int8_t* sr = s + m * K;
        uint8_t* dr = d + m * K4;
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = 0;
            for (int j = 0; j < 4; ++j) {
                int64_t k = k4 * 4 + j;
                if (k >= K) break;
                int8_t v = sr[k];
                uint8_t code = (v == 1) ? 1u : (v == -1 ? 2u : 0u);
                b |= (code << (6 - j * 2));
            }
            dr[k4] = b;
        }
    }
    return out;
}

torch::Tensor unpack_ternary_cpu(torch::Tensor packed, int64_t K) {
    TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8, "expected uint8 [M,K4]");
    TORCH_CHECK(packed.device().is_cpu(), "unpack_ternary: expected a CPU tensor");
    TORCH_CHECK(K >= 0 && K <= packed.size(1) * 4, "unpack_ternary: K must lie in [0, 4 * K4]");
    auto p = packed.contiguous();
    int64_t M = p.size(0), K4 = p.size(1);
    auto out = torch::empty({M, K}, torch::kFloat32);
    const uint8_t* pp = p.data_ptr<uint8_t>();
    float* dp = out.data_ptr<float>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const uint8_t* pr = pp + m * K4;
        float* dr = dp + m * K;
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = pr[k4];
            int64_t base = k4 * 4;
            if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3];
            if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3];
            if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3];
            if (base + 3 < K) dr[base + 3] = LUT[b & 3];
        }
    }
    return out;
}

torch::Tensor dequantize_cpu(torch::Tensor packed, torch::Tensor alpha, int64_t K) {
    TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8, "expected uint8 [M,K4]");
    TORCH_CHECK(packed.device().is_cpu() && alpha.device().is_cpu(), "dequantize: expected CPU tensors");
    TORCH_CHECK(alpha.numel() == packed.size(0), "dequantize: alpha needs one scale per row");
    TORCH_CHECK(K >= 0 && K <= packed.size(1) * 4, "dequantize: K must lie in [0, 4 * K4]");
    auto p = packed.contiguous();
    auto a = alpha.contiguous().to(torch::kFloat32);
    int64_t M = p.size(0), K4 = p.size(1);
    auto out = torch::empty({M, K}, torch::kFloat32);
    const uint8_t* pp = p.data_ptr<uint8_t>();
    const float* ap = a.data_ptr<float>();
    float* dp = out.data_ptr<float>();
    #pragma omp parallel for schedule(static)
    for (int64_t m = 0; m < M; ++m) {
        const uint8_t* pr = pp + m * K4;
        float* dr = dp + m * K;
        float sc = ap[m];
        for (int64_t k4 = 0; k4 < K4; ++k4) {
            uint8_t b = pr[k4];
            int64_t base = k4 * 4;
            if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3] * sc;
            if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3] * sc;
            if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3] * sc;
            if (base + 3 < K) dr[base + 3] = LUT[b & 3] * sc;
        }
    }
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("pack_ternary", &pack_ternary_cpu, "Pack int8 ternary -> 2-bit uint8");
    m.def("unpack_ternary", &unpack_ternary_cpu, "Unpack 2-bit uint8 -> fp32 {-1,0,1}");
    m.def("dequantize", &dequantize_cpu, "Unpack and scale by per-row alpha");
}
"""
def _try_load_native() -> Optional[object]:
    """Build/load the optional C++ extension once and return it (or ``None``).

    The extension is compiled with ``torch.utils.cpp_extension.load_inline``
    into a ``.ternary_build`` directory one level above this file's directory.
    The build is best-effort. On any failure, the (truncated) error text is
    recorded in ``CHIMERA_NATIVE_DISABLED`` if that variable is not already
    set, and ``None`` is returned.

    The lock-free fast path and the locked re-check form double-checked
    locking. ``_NATIVE_TRIED`` is only published after ``_NATIVE_EXT`` has been
    assigned. A concurrent caller therefore either waits on the lock or sees
    the final result. It can never see "tried" while the build is still
    running.
    """
    global _NATIVE_EXT, _NATIVE_TRIED
    if _NATIVE_TRIED:
        return _NATIVE_EXT
    with _NATIVE_LOCK:
        if _NATIVE_TRIED:
            return _NATIVE_EXT
        try:
            from torch.utils.cpp_extension import load_inline
            build_dir = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build"
            )
            os.makedirs(build_dir, exist_ok=True)
            _NATIVE_EXT = load_inline(
                name="chimera_ternary",
                cpp_sources=_CPP_SOURCE,
                extra_cflags=["-O3", "-fopenmp", "-ffast-math", "-funroll-loops"],
                extra_ldflags=["-lgomp"],
                build_directory=build_dir,
                verbose=False,
            )
        except Exception as exc:
            # Deliberately best-effort: fall back to the pure-PyTorch path.
            os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200])
            _NATIVE_EXT = None
        finally:
            # Publish "attempt finished" only now, after _NATIVE_EXT is final.
            _NATIVE_TRIED = True
    return _NATIVE_EXT
def enable_native_kernel(force: bool = False) -> bool:
    """Try to build/load the optional C++/OpenMP ternary extension.

    Args:
        force: when True, forget any earlier attempt (successful or failed)
            so the build/load is attempted again.

    Returns:
        True if the extension is loaded after this call, False otherwise.
    """
    global _NATIVE_TRIED
    if force:
        _NATIVE_TRIED = False
    loaded = _try_load_native()
    return loaded is not None
def native_kernel_available() -> bool:
    """Return True if the C++ extension is already loaded.

    This is a pure query: it never triggers a build.
    """
    ext = _NATIVE_EXT
    return ext is not None
# Opt-in: CHIMERA_NATIVE=1 builds/loads the native kernel at import time.
if os.getenv("CHIMERA_NATIVE", "0") == "1":
    enable_native_kernel()
# ---------------------------------------------------------------------------
# Pure PyTorch ternary primitives.
# ---------------------------------------------------------------------------
# Lookup tables indexed by a 2-bit code:
# 0b00 -> 0, 0b01 -> +1, 0b10 -> -1, 0b11 -> 0 (never produced by pack_ternary).
_TERNARY_LUT_F32 = torch.tensor((0.0, 1.0, -1.0, 0.0), dtype=torch.float32)
_TERNARY_LUT_I8 = torch.tensor((0, 1, -1, 0), dtype=torch.int8)
# Right-shift amounts that extract the four 2-bit fields of a byte, high bits first.
_SHIFTS = torch.tensor((6, 4, 2, 0), dtype=torch.uint8)
def pack_ternary(q: torch.Tensor) -> torch.Tensor:
    """Pack ternary values into 2-bit codes, four values per ``uint8``.

    Values equal to +1 map to 0b01 and -1 to 0b10. Everything else, after the
    cast to int8, maps to 0b00. The first value of each group of four occupies
    the two most significant bits. The last dimension is zero-padded to a
    multiple of 4.

    Args:
        q: tensor whose last dimension (K) holds the ternary values. A 1-D
            input is treated as a single row.

    Returns:
        ``uint8`` tensor of shape ``(*leading, ceil(K / 4))``. A 1-D input
        yields shape ``(1, ceil(K / 4))``.
    """
    src = q.detach()
    if src.dim() == 1:
        src = src.unsqueeze(0)
    lead_shape = src.shape[:-1]
    n_cols = src.shape[-1]
    rows = src.reshape(-1, n_cols).to(torch.int8)
    n_rows = rows.shape[0]
    n_bytes = (n_cols + 3) // 4
    tail = n_bytes * 4 - n_cols
    if tail > 0:
        rows = F.pad(rows, (0, tail))

    # Build the 2-bit code per value: bit 0 marks +1, bit 1 marks -1.
    is_plus = (rows == 1).to(torch.uint8)
    is_minus = (rows == -1).to(torch.uint8)
    codes = (is_plus | (is_minus << 1)).view(n_rows, n_bytes, 4)

    hi_pair = (codes[..., 0] << 6) | (codes[..., 1] << 4)
    lo_pair = (codes[..., 2] << 2) | codes[..., 3]
    packed = (hi_pair | lo_pair).contiguous()
    return packed.reshape(*lead_shape, n_bytes)
def unpack_ternary(packed: torch.Tensor, k: int,
                   alpha: Optional[torch.Tensor] = None,
                   dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Inverse of :func:`pack_ternary`: expand 2-bit codes back to {-1, 0, +1}.

    Args:
        packed: codes with four values per byte along the last dimension. It
            is cast to ``uint8``. A 1-D input is treated as a single row.
        k: number of values to keep per row. The padding beyond ``k`` is
            dropped.
        alpha: optional per-row scale. It is reshaped to ``(rows, 1)``, where
            ``rows`` is the product of all leading dimensions.
        dtype: floating dtype of the result.

    Returns:
        Tensor of shape ``(*leading, k)``. A 1-D input yields ``(1, k)``.
    """
    raw = packed.to(torch.uint8)
    if raw.dim() == 1:
        raw = raw.unsqueeze(0)
    lead_shape = raw.shape[:-1]
    rows = raw.reshape(-1, raw.shape[-1])
    n_rows, n_bytes = rows.shape
    dev = raw.device

    # Split each byte into its four 2-bit fields, most significant first.
    fields = (rows.unsqueeze(-1) >> _SHIFTS.to(dev)) & 3
    table = _TERNARY_LUT_F32.to(device=dev, dtype=dtype)
    values = table[fields.long()].reshape(n_rows, n_bytes * 4)[:, :k]

    if alpha is not None:
        scale = alpha.reshape(n_rows, 1).to(device=values.device, dtype=values.dtype)
        values = values * scale
    return values.reshape(*lead_shape, k)
def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Per-row scale: mean of ``|weight|`` over the last dimension, floored at ``eps``.

    The result is computed from a detached view, so it carries no gradient.
    It is returned as float32 with the last dimension removed.
    """
    row_mean = torch.mean(weight.detach().abs(), dim=-1)
    return torch.clamp(row_mean, min=eps).to(torch.float32)
def ternarize_weight(weight: torch.Tensor, group_size: int = 128
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantise ``weight`` to int8 ternary codes with a per-row absmean scale.

    Args:
        weight: floating-point tensor. The scale is taken over its last
            dimension.
        group_size: accepted for API compatibility but not used. Scaling is
            always per row.

    Returns:
        ``(w_q, alpha)``:
        * ``w_q`` is int8 equal to ``round(clamp(weight / alpha, -1, 1))``,
          i.e. values in {-1, 0, 1}.
        * ``alpha`` is the float32 per-row scale from ``_absmean_alpha``.
    """
    scale = _absmean_alpha(weight)
    normalized = (weight / scale.unsqueeze(-1)).clamp(-1.0, 1.0)
    return normalized.round().to(torch.int8), scale


# Alias of ternarize_weight kept under the older private name.
_quantize_weights_ternary = ternarize_weight
def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
    """Enforce 2:4 structured sparsity in place along the last dimension.

    In every consecutive group of 4 values, the 2 entries with the smallest
    magnitude are set to zero. When the last dimension is not a multiple of 4,
    the final group is zero-padded. The padding zeros take part in the
    ranking, so a short tail group keeps its real values.

    Non-contiguous inputs (e.g. a transposed weight) are supported. The work
    is done on a contiguous copy, which is then written back into ``weight``.

    Args:
        weight: tensor to sparsify. It is modified in place under ``no_grad``.

    Returns:
        ``weight`` itself, for chaining.
    """
    with torch.no_grad():
        last = weight.shape[-1]
        pad = (-last) % 4
        if pad:
            work = F.pad(weight, (0, pad))
        elif weight.is_contiguous():
            work = weight
        else:
            # .view() below requires contiguous memory.
            work = weight.contiguous()
        groups = work.view(*work.shape[:-1], -1, 4)
        idx = groups.abs().argsort(dim=-1)[..., :2]
        groups.scatter_(-1, idx, 0.0)
        if work is not weight:
            weight.copy_(work[..., :last])
    return weight
# ---------------------------------------------------------------------------
# Straight-Through Estimator for ternary quantization.
# ---------------------------------------------------------------------------
#
# CLAMP-AWARE STE using the detach() trick:
#
# clamped = clamp(w, -1, 1)
# w_q = clamped + (round(clamped) - clamped).detach()
#
# Forward: evaluates to round(clamp(w, -1, 1)) — the same value as the legacy _RoundTernarySTE.
# Backward: ∂/∂w [clamp(w, -1, 1)] = 1 if |w| <= 1 else 0.
# → Gradients are ZERO for weights outside [-1, 1] (beyond the quantization range).
# → Gradients pass through unchanged inside [-1, 1] (STE identity).
#
# This prevents gradient explosion that caused NaN at step ~150 with the
# pure identity STE (w + (quant - w).detach()). The clamp derivative acts
# as a natural gradient gate: weights that have drifted beyond the ternary
# range get no gradient push, preventing runaway accumulation.
#
# Ref: 4-bit CPU training (arxiv:2603.13931) uses tanh soft clipping for
# the same stabilization purpose.
# ---------------------------------------------------------------------------
class _RoundTernarySTE(torch.autograd.Function):
    """LEGACY — kept for backward compat. Use ste_ternary() instead.

    Forward returns ``round(clamp(w, -1, 1))``.

    Backward clips the incoming gradient *values* to [-1, 1]. Unlike
    ste_ternary, it does not zero the gradient for weights outside [-1, 1].
    """

    @staticmethod
    def forward(ctx, w: torch.Tensor) -> torch.Tensor:
        return w.clamp(-1.0, 1.0).round()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        return torch.clamp(grad_output, min=-1.0, max=1.0)
def ste_ternary(w: torch.Tensor) -> torch.Tensor:
    """Ternarise ``w`` with a clamp-aware straight-through estimator.

    Forward value: ``round(clamp(w, -1, 1))``.

    Gradient: that of ``clamp(w, -1, 1)``, i.e. identity inside [-1, 1] and
    zero outside, because the rounding residual is detached. The result is a
    plain differentiable expression (detach() trick) rather than a custom
    ``autograd.Function``.
    """
    bounded = w.clamp(-1.0, 1.0)
    residual = (bounded.round() - bounded).detach()
    return bounded + residual
# ---------------------------------------------------------------------------
# BitLinear
# ---------------------------------------------------------------------------
class BitLinear(nn.Module):
    """Linear layer with ternary {-1, 0, 1} weights and per-row absmean scale.

    *Training* (``self.training`` and grad enabled): the FP32 latent
    ``weight`` is ternarised on the fly with a clamp-aware STE.

    *Inference* (otherwise): the ternary weights are packed to 2-bit ``uint8``
    (``_packed``) with per-row scales (``_alpha``). A dequantised dense copy
    (``_dense_w``) is cached between calls and fed to ``F.linear``.

    Cache invalidation: the caches are keyed on the weight's identity. They
    are rebuilt in each of these cases:

    * ``weight`` is updated in place, e.g. ``optimizer.step()`` or
      ``load_state_dict``; this bumps the tensor's version counter.
    * ``weight`` is replaced by a new Parameter, or converted to another
      device or dtype.
    * :meth:`invalidate_packed` or :meth:`reset_parameters` is called.

    Writes made through ``weight.data`` are not seen by the version counter,
    so call :meth:`invalidate_packed` after those.

    The C++ extension, when loaded, is only used for CPU tensors.

    Args:
        in_features: input feature count (K).
        out_features: output feature count (M).
        bias: add a learnable bias initialised to zero.
        group_size: stored on the module. The quantisation in this class
            always uses one scale per output row.
        nm_2_4: additionally enforce 2:4 structured sparsity on the ternary
            weights.
    """
    __constants__ = ["in_features", "out_features", "use_2_4"]

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 group_size: int = 128, nm_2_4: bool = False):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.group_size = int(group_size)
        self.use_2_4 = bool(nm_2_4)
        self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(self.out_features))
        else:
            self.register_parameter("bias", None)
        # Inference caches. They are buffers so .to()/.cuda() move them, but
        # non-persistent so they never enter state_dict.
        self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
        self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
        self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
        # Cache keys: -1 means "never built". A built cache stores the
        # _weight_key() value it was derived from.
        self._packed_version = -1
        self._dense_version = -1
        # Manual invalidation counter, bumped by reset_parameters/invalidate_packed.
        self._cache_version = 0
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise ``weight`` (Kaiming-uniform) and ``bias`` (zeros); invalidate caches."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        self._cache_version += 1

    def invalidate_packed(self) -> None:
        """Force the packed and dense inference caches to be rebuilt on next use."""
        self._cache_version += 1
        if self._dense_w.numel() > 0:
            # Drop the dense copy now rather than holding it until the rebuild.
            self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
        self._dense_version = -1

    def _weight_key(self) -> Tuple[object, ...]:
        """Identity of the current latent weights, used to validate the caches."""
        w = self.weight
        # Inference tensors (created under torch.inference_mode) have no
        # version counter; fall back to the other key components for them.
        version = 0 if w.is_inference() else w._version
        return (self._cache_version, id(w), version, w.device, w.dtype)

    @staticmethod
    def _native_for(t: torch.Tensor) -> Optional[object]:
        """Return the loaded C++ extension if it can handle ``t`` (CPU only), else None."""
        ext = _NATIVE_EXT
        if ext is not None and t.device.type == "cpu":
            return ext
        return None

    def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(w_q, alpha)``.

        ``w_q`` holds the int8 ternary weights (2:4-sparsified if enabled) and
        ``alpha`` the fp32 per-row scales.
        """
        with torch.no_grad():
            w = self.weight
            alpha = _absmean_alpha(w)
            w_q = torch.round(torch.clamp(w / alpha.unsqueeze(-1), -1.0, 1.0))
            if self.use_2_4:
                apply_2_4_sparsity_(w_q)
            return w_q.to(torch.int8), alpha

    def _ensure_packed(self) -> None:
        """(Re)build ``_packed``/``_alpha`` if the latent weights changed since the last build."""
        key = self._weight_key()
        if self._packed_version == key and self._packed.numel() > 0:
            return
        with torch.no_grad():
            w_q, alpha = self._quantize_latent()
            ext = self._native_for(w_q)
            if ext is not None:
                packed = ext.pack_ternary(w_q)
            else:
                packed = pack_ternary(w_q)
            self._packed = packed.contiguous()
            self._alpha = alpha.contiguous()
        self._packed_version = key

    @torch.no_grad()
    def prepare_for_inference(self) -> None:
        """Invalidate the caches and eagerly rebuild the packed weights."""
        self.invalidate_packed()
        self._ensure_packed()

    @torch.no_grad()
    def ternary_nonzero_mask(self) -> torch.Tensor:
        """Boolean ``[out_features, in_features]`` mask of the nonzero ternary weights."""
        self._ensure_packed()
        ext = self._native_for(self._packed)
        if ext is not None:
            w = ext.unpack_ternary(self._packed, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features)
        return w.ne(0)

    def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
        """STE forward with clamp-aware gradient gating.

        Forward: ``round(clamp(w / alpha, -1, 1)) * alpha``, i.e. ternary
        values times the per-row scale.

        Backward: the gradient is zero where ``w / alpha`` lies outside
        [-1, 1] and passes straight through inside. ``alpha`` is computed from
        the detached weight.
        """
        w = self.weight
        alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
        # Same clamp-then-detached-rounding STE as ste_ternary().
        w_q = ste_ternary(w / alpha) * alpha
        if self.use_2_4:
            with torch.no_grad():
                mask = (apply_2_4_sparsity_(w_q.detach().clone()) != 0).to(w_q.dtype)
            w_q = w_q * mask
        return F.linear(x, w_q.to(x.dtype), self.bias)

    def _ensure_dense(self) -> torch.Tensor:
        """Return the cached dequantised ``[out, in]`` weight, rebuilding it if stale."""
        self._ensure_packed()
        key = self._packed_version
        if self._dense_version == key and self._dense_w.numel() > 0:
            return self._dense_w
        ext = self._native_for(self._packed)
        if ext is not None:
            w = ext.dequantize(self._packed, self._alpha, self.in_features)
        else:
            w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1)
        self._dense_w = w.contiguous()
        self._dense_version = key
        return self._dense_w

    def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
        """Inference forward using the cached dequantised weights (cast to ``x.dtype``)."""
        w = self._ensure_dense()
        w_used = w if x.dtype == w.dtype else w.to(x.dtype)
        return F.linear(x, w_used, self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Use the STE path when training with grad enabled, else the cached inference path."""
        if self.training and torch.is_grad_enabled():
            return self._forward_train(x)
        return self._forward_packed(x)

    def extra_repr(self) -> str:
        return (f"in_features={self.in_features}, out_features={self.out_features}, "
                f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, "
                f"native={native_kernel_available()}")
# ---------------------------------------------------------------------------
# RMSNorm
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    ``y = x * rsqrt(mean(x**2, dim=-1) + eps) * weight``.

    Inputs that are not float32 are normalised in float32 and cast back to
    their own dtype before the gain is applied. The final product therefore
    follows normal type promotion with ``weight``.
    """
    __constants__ = ["dim", "eps"]

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = int(dim)
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        upcast = in_dtype != torch.float32
        xf = x.float() if upcast else x
        inv_rms = xf.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        normed = xf * inv_rms
        if upcast:
            normed = normed.to(in_dtype)
        return normed * self.weight
# Public API exported by ``from <module> import *``.
# ``_quantize_weights_ternary`` is the older private alias of ternarize_weight.
__all__ = [
    "BitLinear", "RMSNorm",
    "ste_ternary",
    "pack_ternary", "unpack_ternary",
    "ternarize_weight", "_quantize_weights_ternary",
    "apply_2_4_sparsity_",
    "enable_native_kernel", "native_kernel_available",
]