fix: NaN at step 150 — add gradient clamping to STE detach trick + lower max_grad_norm to 0.5

The pure detach() STE passes gradients through unbounded, causing
gradient explosion around step 140-150 while the loss is still high.

Fix: clamp the gradient contribution within the detach trick:

    clamped = clamp(w_scaled, -1, 1)
    w_q = clamped + (round(clamped) - clamped).detach()

This makes gradients zero outside [-1, 1] (weights already at the
quantization boundary get no gradient push) while keeping the STE
identity pass-through inside the valid range.

Also reduces max_grad_norm from 1.0 to 0.5 for additional stability.

Ref: the 4-bit CPU training paper (2603.13931) uses tanh soft clipping
for the same reason.
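Illustration (not part of the diff): a minimal sketch of the clamp-aware STE on a
throwaway example tensor. The forward value is round(clamp(w, -1, 1)); the backward
pass only sees the clamp, so gradients vanish for |w| > 1 and pass through unchanged
inside the range.

    import torch

    # Weights inside and outside the ternary range [-1, 1].
    w = torch.tensor([-2.0, -0.4, 0.3, 1.7], requires_grad=True)

    # Clamp first, then detach only the rounding residual.
    clamped = torch.clamp(w, -1.0, 1.0)
    w_q = clamped + (torch.round(clamped) - clamped).detach()

    w_q.sum().backward()
    print(w_q.detach())  # tensor([-1., -0., 0., 1.])  -> ternary forward values
    print(w.grad)        # tensor([0., 1., 1., 0.])    -> zero gradient outside [-1, 1]

The max_grad_norm change lives in the trainer rather than in this file; with a
hypothetical model object it amounts to calling
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5) at the optimizer step.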
chimera/quantization.py  (+35 -95)
@@ -37,9 +37,7 @@ import torch.nn.functional as F
 
 
 # ---------------------------------------------------------------------------
-# Lazy C++ kernel.
-# when explicitly requested via :func:`enable_native_kernel` or the env var
-# ``CHIMERA_NATIVE=1``. All public APIs work with the pure-PyTorch path.
+# Lazy C++ kernel.
 # ---------------------------------------------------------------------------
 
 _NATIVE_LOCK = threading.Lock()
@@ -55,7 +53,6 @@ _CPP_SOURCE = r"""
 #include <omp.h>
 #endif
 
-// Encoding: -1->0b10, 0->0b00, +1->0b01
 static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f};
 
 torch::Tensor pack_ternary_cpu(torch::Tensor w) {
@@ -108,8 +105,6 @@ torch::Tensor unpack_ternary_cpu(torch::Tensor packed, int64_t K) {
     return out;
 }
 
-// Fused "unpack and scale" -> bf16/fp32 dense weight. Saves a pass over memory
-// and a temporary FP32 tensor when running under bf16 autocast.
 torch::Tensor dequantize_cpu(torch::Tensor packed, torch::Tensor alpha, int64_t K) {
     auto p = packed.contiguous();
     auto a = alpha.contiguous().to(torch::kFloat32);
@@ -144,7 +139,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 
 
 def _try_load_native() -> Optional[object]:
-    """Compile/load the optional native helper. Idempotent and thread-safe."""
     global _NATIVE_EXT, _NATIVE_TRIED
     if _NATIVE_TRIED:
         return _NATIVE_EXT
@@ -154,7 +148,6 @@ def _try_load_native() -> Optional[object]:
     _NATIVE_TRIED = True
     try:
         from torch.utils.cpp_extension import load_inline
-
         build_dir = os.path.join(
             os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build"
         )
@@ -167,17 +160,13 @@ def _try_load_native() -> Optional[object]:
            build_directory=build_dir,
            verbose=False,
        )
-    except Exception as exc:
+    except Exception as exc:
        os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200])
        _NATIVE_EXT = None
    return _NATIVE_EXT


def enable_native_kernel(force: bool = False) -> bool:
-    """Eagerly try to compile the native kernel.
-
-    Returns ``True`` if the kernel is loaded and available.
-    """
    global _NATIVE_TRIED
    if force:
        _NATIVE_TRIED = False
@@ -188,28 +177,20 @@ def native_kernel_available() -> bool:
    return _NATIVE_EXT is not None


-# Allow opt-in from the environment without code changes.
 if os.environ.get("CHIMERA_NATIVE", "0") == "1":
     enable_native_kernel()


 # ---------------------------------------------------------------------------
-# Pure PyTorch ternary primitives
+# Pure PyTorch ternary primitives.
 # ---------------------------------------------------------------------------

-# Lookup tables compiled once. Casting to a registered buffer is overkill –
-# they live on CPU and broadcast naturally.
 _TERNARY_LUT_F32 = torch.tensor([0.0, 1.0, -1.0, 0.0], dtype=torch.float32)
 _TERNARY_LUT_I8 = torch.tensor([0, 1, -1, 0], dtype=torch.int8)
 _SHIFTS = torch.tensor([6, 4, 2, 0], dtype=torch.uint8)


 def pack_ternary(q: torch.Tensor) -> torch.Tensor:
-    """Pack a ternary {-1,0,1} tensor into a 2-bit uint8 tensor.
-
-    Vectorised pure-PyTorch implementation — no Python loops over rows.
-    Trailing positions that don't divide by four are zero-padded.
-    """
     q = q.detach()
     if q.dim() == 1:
         q = q.unsqueeze(0)
@@ -219,7 +200,6 @@ def pack_ternary(q: torch.Tensor) -> torch.Tensor:
     pad = K4 * 4 - K
     if pad:
         flat = F.pad(flat, (0, pad))
-    # codes: 0 / 1 / 2 (uint8)
     codes = torch.where(flat == 1, torch.full_like(flat, 1),
                         torch.where(flat == -1, torch.full_like(flat, 2), torch.zeros_like(flat))).to(torch.uint8)
     codes = codes.view(M, K4, 4)
@@ -231,19 +211,13 @@ def pack_ternary(q: torch.Tensor) -> torch.Tensor:
 def unpack_ternary(packed: torch.Tensor, k: int,
                    alpha: Optional[torch.Tensor] = None,
                    dtype: torch.dtype = torch.float32) -> torch.Tensor:
-    """Vectorised inverse of :func:`pack_ternary`.
-
-    Returns ``out`` with last dim ``k``; optionally pre-multiplied by
-    ``alpha`` (per-row scale, broadcastable on the leading axes).
-    """
     packed = packed.to(torch.uint8)
     if packed.dim() == 1:
         packed = packed.unsqueeze(0)
     flat = packed.reshape(-1, packed.shape[-1])
     M, K4 = flat.shape
-    # Gather all 4 sub-positions in one vectorised op.
     shifts = _SHIFTS.to(packed.device)
-    codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long)
+    codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long)
     lut = _TERNARY_LUT_F32.to(device=packed.device, dtype=dtype)
     out = lut[codes].reshape(M, K4 * 4)[:, :k]
     if alpha is not None:
@@ -252,33 +226,25 @@ def unpack_ternary(packed: torch.Tensor, k: int,


 def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
-    """Per-output-channel scale (``\\alpha = mean|w|`` clamped)."""
     return weight.detach().abs().mean(dim=-1, keepdim=False).clamp_min(eps).to(torch.float32)


 def ternarize_weight(weight: torch.Tensor, group_size: int = 128
                      ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantise FP32 weights to ternary using BitNet's abs-mean rule.
-
-    ``group_size`` is kept for API compatibility but every row is its own
-    group in this slim implementation. Returns ``(w_ternary, alpha)``.
-    """
     alpha = _absmean_alpha(weight)
     w_q = torch.round(torch.clamp(weight / alpha.unsqueeze(-1), -1.0, 1.0)).to(torch.int8)
     return w_q, alpha


-_quantize_weights_ternary = ternarize_weight
+_quantize_weights_ternary = ternarize_weight


 def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
-    """In-place N:M 2:4 pruning. Vectorised — no Python row loops."""
     with torch.no_grad():
         last = weight.shape[-1]
         pad = (-last) % 4
         target = F.pad(weight, (0, pad)) if pad else weight
         view = target.view(*target.shape[:-1], -1, 4)
-        # Keep the two largest in absolute value, zero the rest.
         idx = view.abs().argsort(dim=-1)[..., :2]
         view.scatter_(-1, idx, 0.0)
         if pad:
@@ -290,29 +256,25 @@ def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
 # Straight-Through Estimator for ternary quantization.
 # ---------------------------------------------------------------------------
 #
-#
+# CLAMP-AWARE STE using the detach() trick:
 #
-#
+#     clamped = clamp(w, -1, 1)
+#     w_q = clamped + (round(clamped) - clamped).detach()
 #
-# Forward: evaluates to round(clamp(w, -1, 1))
-# Backward: ∂/∂w [w
+# Forward: evaluates to round(clamp(w, -1, 1)) — same as before.
+# Backward: ∂/∂w [clamp(w, -1, 1)] = 1 if |w| <= 1 else 0.
+# → Gradients are ZERO for weights outside [-1, 1] (at quantization boundary).
+# → Gradients pass through unchanged inside [-1, 1] (STE identity).
 #
-# This
-#
-#
-#
+# This prevents gradient explosion that caused NaN at step ~150 with the
+# pure identity STE (w + (quant - w).detach()). The clamp derivative acts
+# as a natural gradient gate: weights that have drifted beyond the ternary
+# range get no gradient push, preventing runaway accumulation.
 #
-#
-#
-# Note: the old STE also clipped gradients to [-1, 1]. The detach trick
-# passes gradients through unclipped, which is actually better for convergence
-# (see BitNet b1.58 Reloaded, arxiv:2407.09527). If you need grad clipping,
-# use torch.nn.utils.clip_grad_norm_() at the optimizer step instead.
+# Ref: 4-bit CPU training (arxiv:2603.13931) uses tanh soft clipping for
+# the same stabilization purpose.
 # ---------------------------------------------------------------------------

-# Keep the old class around for backward compatibility (MeZOOptimizer uses it
-# indirectly through ternary_nonzero_mask), but it is no longer called in the
-# training forward path.
 class _RoundTernarySTE(torch.autograd.Function):
     """LEGACY — kept for backward compat. Use ste_ternary() instead."""
     @staticmethod
@@ -328,28 +290,24 @@ def ste_ternary(w: torch.Tensor) -> torch.Tensor:
     """Straight-through estimator for ternary quantization.

     Forward: round(clamp(w, -1, 1))
-    Backward:
+    Backward: clamp derivative (zero outside [-1, 1], identity inside)

     Uses the detach() trick for zero graph breaks under torch.compile.
     """
-
-
+    clamped = torch.clamp(w, -1.0, 1.0)
+    w_q = torch.round(clamped)
+    return clamped + (w_q - clamped).detach()


 # ---------------------------------------------------------------------------
-# BitLinear
+# BitLinear
 # ---------------------------------------------------------------------------

 class BitLinear(nn.Module):
     """Linear layer with ternary {-1, 0, 1} weights and per-row absmean scale.

-    *Training
-
-    Uses detach() trick — fully torch.compile compatible (zero graph breaks).
-
-    *Inference / no-grad*: weights are quantised once and cached as packed
-    2-bit uint8 + fp32 alpha. Each forward unpacks (vectorised PyTorch or
-    optional C++ kernel) into a reusable buffer and calls a single matmul.
+    *Training*: STE ternarisation with clamp-aware gradient gating.
+    *Inference*: cached packed 2-bit uint8 weights.
     """

     __constants__ = ["in_features", "out_features", "use_2_4"]
@@ -368,7 +326,6 @@ class BitLinear(nn.Module):
         else:
             self.register_parameter("bias", None)

-        # Caches for inference path.
         self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
         self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
         self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
@@ -378,25 +335,19 @@ class BitLinear(nn.Module):

         self.reset_parameters()

-    # -- init ------------------------------------------------------------------
-
     def reset_parameters(self) -> None:
         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
         if self.bias is not None:
             nn.init.zeros_(self.bias)
         self._cache_version += 1

-    # -- helpers ---------------------------------------------------------------
-
     def invalidate_packed(self) -> None:
-        """Mark the packed cache stale. Called after weight mutations."""
         self._cache_version += 1
         if self._dense_w.numel() > 0:
             self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
         self._dense_version = -1

     def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Quantise the FP32 latent weight to ternary (no-grad, no copy)."""
         with torch.no_grad():
             w = self.weight
             alpha = _absmean_alpha(w)
@@ -421,13 +372,11 @@

     @torch.no_grad()
     def prepare_for_inference(self) -> None:
-        """Materialise the packed cache so the next forward is allocation-free."""
         self.invalidate_packed()
         self._ensure_packed()

     @torch.no_grad()
     def ternary_nonzero_mask(self) -> torch.Tensor:
-        """Boolean mask of currently non-zero ternary positions (cached)."""
         self._ensure_packed()
         ext = _NATIVE_EXT
         if ext is not None:
@@ -436,24 +385,21 @@
         w = unpack_ternary(self._packed, self.in_features)
         return w.ne(0)

-    # -- forward ---------------------------------------------------------------
-
     def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
-        """STE forward
-
-        Uses detach() trick for torch.compile compatibility:
-            w_scaled = w / alpha
-            w_q = w_scaled + (round(clamp(w_scaled)) - w_scaled).detach()
-            output = F.linear(x, w_q * alpha)
+        """STE forward with clamp-aware gradient gating.

-
-
+        The clamp on w_scaled ensures:
+          - Forward: round(clamp(w/alpha, -1, 1)) * alpha — correct ternary
+          - Backward: gradient is ZERO for w_scaled outside [-1, 1],
+            preventing gradient explosion from weights at the boundary.
         """
         w = self.weight
         alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
         w_scaled = w / alpha
-        #
-
+        # Clamp FIRST, then detach the rounding residual.
+        # Gradient of clamp: 1 inside [-1,1], 0 outside → natural gradient gate
+        clamped = torch.clamp(w_scaled, -1.0, 1.0)
+        w_q = clamped + (torch.round(clamped) - clamped).detach()
         w_q = w_q * alpha
         if self.use_2_4:
             with torch.no_grad():
@@ -462,7 +408,6 @@
         return F.linear(x, w_q.to(x.dtype), self.bias)

     def _ensure_dense(self) -> torch.Tensor:
-        """Materialise (and cache) the fp32 dense ternary weight."""
         self._ensure_packed()
         if self._dense_version == self._cache_version and self._dense_w.numel() > 0:
             return self._dense_w
@@ -476,7 +421,6 @@
         return self._dense_w

     def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
-        """No-grad fast path that uses the cached dequantised weights."""
         w = self._ensure_dense()
         if x.dtype != w.dtype:
             w_used = w.to(x.dtype)
@@ -489,8 +433,6 @@
             return self._forward_train(x)
         return self._forward_packed(x)

-    # -- introspection ---------------------------------------------------------
-
     def extra_repr(self) -> str:
         return (f"in_features={self.in_features}, out_features={self.out_features}, "
                 f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, "
@@ -498,12 +440,10 @@


 # ---------------------------------------------------------------------------
-# RMSNorm
+# RMSNorm
 # ---------------------------------------------------------------------------

 class RMSNorm(nn.Module):
-    """Numerically-stable Root Mean Square LayerNorm (no bias, no centering)."""
-
     __constants__ = ["dim", "eps"]

     def __init__(self, dim: int, eps: float = 1e-6):
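For reference, a small standalone sketch (not part of the commit; the example values
are arbitrary) of the 2-bit ternary packing scheme the cached inference path uses:
four codes per uint8 with shifts (6, 4, 2, 0), encoding -1 -> 0b10, 0 -> 0b00,
+1 -> 0b01, decoded through the LUT [0, 1, -1, 0].

    import torch

    vals = torch.tensor([1, -1, 0, 1], dtype=torch.int8)

    # Encode: +1 -> code 1, -1 -> code 2, 0 -> code 0 (two bits each, four per byte).
    codes = torch.where(vals == 1, torch.ones_like(vals),
                        torch.where(vals == -1, torch.full_like(vals, 2),
                                    torch.zeros_like(vals))).long()
    shifts = torch.tensor([6, 4, 2, 0])
    packed = (codes << shifts).sum().to(torch.uint8)   # 0b01100001 == 97

    # Decode: extract each 2-bit code and map it through the LUT [0, 1, -1, 0].
    lut = torch.tensor([0, 1, -1, 0], dtype=torch.int8)
    decoded = lut[(packed.long() >> shifts) & 3]
    assert torch.equal(decoded, vals)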