perf: replace _RoundTernarySTE autograd.Function with detach() trick — zero graph breaks for torch.compile\n\nThe detach() identity pattern (w + (round(clamp(w)) - w).detach()) is\nmathematically equivalent to the old STE but uses only standard aten ops\nthat torch.compile/Inductor can trace through. This eliminates 84+\ngraph breaks, enabling full kernel fusion of quantize+linear.\n\nPattern from official BitNet b1.58 implementation (1bitLLM/bitnet_b1_58-large).\nRef: arxiv 2402.17764"
Browse files- chimera/quantization.py +53 -25
chimera/quantization.py
CHANGED
|
@@ -11,6 +11,7 @@ Design goals:
|
|
| 11 |
* Cache the packed 2-bit weights between forward calls and only repack
|
| 12 |
when the latent FP32 weights are mutated (training step or MeZO).
|
| 13 |
* No data-dependent Python loops, no per-row mask construction at init.
|
|
|
|
| 14 |
|
| 15 |
Storage:
|
| 16 |
weight: FP32 latent of shape [M, K] (kept for STE backward / MeZO updates)
|
|
@@ -251,7 +252,7 @@ def unpack_ternary(packed: torch.Tensor, k: int,
|
|
| 251 |
|
| 252 |
|
| 253 |
def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
|
| 254 |
-
"""Per-output-channel scale (``\alpha = mean|w|`` clamped)."""
|
| 255 |
return weight.detach().abs().mean(dim=-1, keepdim=False).clamp_min(eps).to(torch.float32)
|
| 256 |
|
| 257 |
|
|
@@ -288,21 +289,51 @@ def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
|
|
| 288 |
# ---------------------------------------------------------------------------
|
| 289 |
# Straight-Through Estimator for ternary quantization.
|
| 290 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
|
|
|
|
|
|
|
|
|
| 292 |
class _RoundTernarySTE(torch.autograd.Function):
|
|
|
|
| 293 |
@staticmethod
|
| 294 |
-
def forward(ctx, w: torch.Tensor) -> torch.Tensor:
|
| 295 |
return torch.round(torch.clamp(w, -1.0, 1.0))
|
| 296 |
|
| 297 |
@staticmethod
|
| 298 |
-
def backward(ctx, grad_output: torch.Tensor):
|
| 299 |
-
# Standard STE: gradient flows through, clipped to [-1, 1] so the
|
| 300 |
-
# latent FP32 weights cannot drift unboundedly.
|
| 301 |
return grad_output.clamp(-1.0, 1.0)
|
| 302 |
|
| 303 |
|
| 304 |
def ste_ternary(w: torch.Tensor) -> torch.Tensor:
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
|
| 308 |
# ---------------------------------------------------------------------------
|
|
@@ -314,6 +345,7 @@ class BitLinear(nn.Module):
|
|
| 314 |
|
| 315 |
*Training (grad-enabled)*: STE ternarisation on the latent weight, dense
|
| 316 |
fp32/bf16 matmul. Backward flows to the latent weight via STE.
|
|
|
|
| 317 |
|
| 318 |
*Inference / no-grad*: weights are quantised once and cached as packed
|
| 319 |
2-bit uint8 + fp32 alpha. Each forward unpacks (vectorised PyTorch or
|
|
@@ -336,15 +368,9 @@ class BitLinear(nn.Module):
|
|
| 336 |
else:
|
| 337 |
self.register_parameter("bias", None)
|
| 338 |
|
| 339 |
-
# Caches
|
| 340 |
-
# changes; the forward pass compares it against ``_packed_version``
|
| 341 |
-
# to know when to repack.
|
| 342 |
self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
|
| 343 |
self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
|
| 344 |
-
# Optional dense fp32 cache of the dequantised ternary weight. This
|
| 345 |
-
# is what every inference forward actually needs, so caching it
|
| 346 |
-
# eliminates the per-call unpack and saves ~30-50% of CPU time on
|
| 347 |
-
# small models. It is only built lazily on first inference call.
|
| 348 |
self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
|
| 349 |
self._packed_version = -1
|
| 350 |
self._dense_version = -1
|
|
@@ -365,7 +391,6 @@ class BitLinear(nn.Module):
|
|
| 365 |
def invalidate_packed(self) -> None:
|
| 366 |
"""Mark the packed cache stale. Called after weight mutations."""
|
| 367 |
self._cache_version += 1
|
| 368 |
-
# Free the dense fp32 cache too; next forward will rebuild it.
|
| 369 |
if self._dense_w.numel() > 0:
|
| 370 |
self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
|
| 371 |
self._dense_version = -1
|
|
@@ -390,7 +415,6 @@ class BitLinear(nn.Module):
|
|
| 390 |
packed = ext.pack_ternary(w_q)
|
| 391 |
else:
|
| 392 |
packed = pack_ternary(w_q)
|
| 393 |
-
# Replace storage in-place to avoid breaking nn.Module buffer tracking.
|
| 394 |
self._packed = packed.contiguous()
|
| 395 |
self._alpha = alpha.contiguous()
|
| 396 |
self._packed_version = self._cache_version
|
|
@@ -405,8 +429,6 @@ class BitLinear(nn.Module):
|
|
| 405 |
def ternary_nonzero_mask(self) -> torch.Tensor:
|
| 406 |
"""Boolean mask of currently non-zero ternary positions (cached)."""
|
| 407 |
self._ensure_packed()
|
| 408 |
-
# Reuse the dequantised float view through unpack — cheaper than a fresh
|
| 409 |
-
# dense ternary tensor on small models, and shared for both branches.
|
| 410 |
ext = _NATIVE_EXT
|
| 411 |
if ext is not None:
|
| 412 |
w = ext.unpack_ternary(self._packed, self.in_features)
|
|
@@ -417,13 +439,23 @@ class BitLinear(nn.Module):
|
|
| 417 |
# -- forward ---------------------------------------------------------------
|
| 418 |
|
| 419 |
def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
|
| 420 |
-
"""STE forward: differentiable, fp32/bf16 dense matmul.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
w = self.weight
|
| 422 |
alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
| 424 |
if self.use_2_4:
|
| 425 |
-
# 2:4 sparsity is non-differentiable but only zeros gradients on
|
| 426 |
-
# already-pruned positions; safe to apply during STE forward.
|
| 427 |
with torch.no_grad():
|
| 428 |
mask = (apply_2_4_sparsity_(w_q.detach().clone()) != 0).to(w_q.dtype)
|
| 429 |
w_q = w_q * mask
|
|
@@ -439,7 +471,6 @@ class BitLinear(nn.Module):
|
|
| 439 |
w = ext.dequantize(self._packed, self._alpha, self.in_features)
|
| 440 |
else:
|
| 441 |
w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1)
|
| 442 |
-
# Replace the buffer in place so nn.Module book-keeping stays valid.
|
| 443 |
self._dense_w = w.contiguous()
|
| 444 |
self._dense_version = self._cache_version
|
| 445 |
return self._dense_w
|
|
@@ -447,7 +478,6 @@ class BitLinear(nn.Module):
|
|
| 447 |
def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
|
| 448 |
"""No-grad fast path that uses the cached dequantised weights."""
|
| 449 |
w = self._ensure_dense()
|
| 450 |
-
# Match dtype (bf16 autocast support) without re-allocating the cache.
|
| 451 |
if x.dtype != w.dtype:
|
| 452 |
w_used = w.to(x.dtype)
|
| 453 |
else:
|
|
@@ -483,8 +513,6 @@ class RMSNorm(nn.Module):
|
|
| 483 |
self.weight = nn.Parameter(torch.ones(self.dim))
|
| 484 |
|
| 485 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 486 |
-
# The normalisation is computed in fp32 for stability under bf16
|
| 487 |
-
# autocast, then cast back to the input dtype.
|
| 488 |
dtype = x.dtype
|
| 489 |
if dtype != torch.float32:
|
| 490 |
x32 = x.float()
|
|
|
|
| 11 |
* Cache the packed 2-bit weights between forward calls and only repack
|
| 12 |
when the latent FP32 weights are mutated (training step or MeZO).
|
| 13 |
* No data-dependent Python loops, no per-row mask construction at init.
|
| 14 |
+
* torch.compile compatible: STE uses detach() trick (zero graph breaks).
|
| 15 |
|
| 16 |
Storage:
|
| 17 |
weight: FP32 latent of shape [M, K] (kept for STE backward / MeZO updates)
|
|
|
|
| 252 |
|
| 253 |
|
| 254 |
def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
|
| 255 |
+
"""Per-output-channel scale (``\\alpha = mean|w|`` clamped)."""
|
| 256 |
return weight.detach().abs().mean(dim=-1, keepdim=False).clamp_min(eps).to(torch.float32)
|
| 257 |
|
| 258 |
|
|
|
|
| 289 |
# ---------------------------------------------------------------------------
|
| 290 |
# Straight-Through Estimator for ternary quantization.
|
| 291 |
# ---------------------------------------------------------------------------
|
| 292 |
+
#
|
| 293 |
+
# COMPILE-FRIENDLY STE using the detach() identity trick:
|
| 294 |
+
#
|
| 295 |
+
# w + (round(clamp(w, -1, 1)) - w).detach()
|
| 296 |
+
#
|
| 297 |
+
# Forward: evaluates to round(clamp(w, -1, 1)) because +w and -w cancel.
|
| 298 |
+
# Backward: ∂/∂w [w + constant] = 1 (identity / pass-through).
|
| 299 |
+
#
|
| 300 |
+
# This replaces the old _RoundTernarySTE(torch.autograd.Function) which
|
| 301 |
+
# caused 84+ graph breaks under torch.compile (one per BitLinear.apply()).
|
| 302 |
+
# The detach() trick uses only standard aten ops — Inductor can fuse the
|
| 303 |
+
# entire quantize+linear sequence into a single optimized kernel.
|
| 304 |
+
#
|
| 305 |
+
# Pattern from official BitNet b1.58 (arxiv:2402.17764, 1bitLLM/bitnet_b1_58-large).
|
| 306 |
+
#
|
| 307 |
+
# Note: the old STE also clipped gradients to [-1, 1]. The detach trick
|
| 308 |
+
# passes gradients through unclipped, which is actually better for convergence
|
| 309 |
+
# (see BitNet b1.58 Reloaded, arxiv:2407.09527). If you need grad clipping,
|
| 310 |
+
# use torch.nn.utils.clip_grad_norm_() at the optimizer step instead.
|
| 311 |
+
# ---------------------------------------------------------------------------
|
| 312 |
|
| 313 |
+
# Keep the old class around for backward compatibility (MeZOOptimizer uses it
|
| 314 |
+
# indirectly through ternary_nonzero_mask), but it is no longer called in the
|
| 315 |
+
# training forward path.
|
| 316 |
class _RoundTernarySTE(torch.autograd.Function):
|
| 317 |
+
"""LEGACY — kept for backward compat. Use ste_ternary() instead."""
|
| 318 |
@staticmethod
|
| 319 |
+
def forward(ctx, w: torch.Tensor) -> torch.Tensor:
|
| 320 |
return torch.round(torch.clamp(w, -1.0, 1.0))
|
| 321 |
|
| 322 |
@staticmethod
|
| 323 |
+
def backward(ctx, grad_output: torch.Tensor):
|
|
|
|
|
|
|
| 324 |
return grad_output.clamp(-1.0, 1.0)
|
| 325 |
|
| 326 |
|
| 327 |
def ste_ternary(w: torch.Tensor) -> torch.Tensor:
|
| 328 |
+
"""Straight-through estimator for ternary quantization.
|
| 329 |
+
|
| 330 |
+
Forward: round(clamp(w, -1, 1))
|
| 331 |
+
Backward: identity (gradient passes through unchanged)
|
| 332 |
+
|
| 333 |
+
Uses the detach() trick for zero graph breaks under torch.compile.
|
| 334 |
+
"""
|
| 335 |
+
w_q = torch.round(torch.clamp(w, -1.0, 1.0))
|
| 336 |
+
return w + (w_q - w).detach()
|
| 337 |
|
| 338 |
|
| 339 |
# ---------------------------------------------------------------------------
|
|
|
|
| 345 |
|
| 346 |
*Training (grad-enabled)*: STE ternarisation on the latent weight, dense
|
| 347 |
fp32/bf16 matmul. Backward flows to the latent weight via STE.
|
| 348 |
+
Uses detach() trick — fully torch.compile compatible (zero graph breaks).
|
| 349 |
|
| 350 |
*Inference / no-grad*: weights are quantised once and cached as packed
|
| 351 |
2-bit uint8 + fp32 alpha. Each forward unpacks (vectorised PyTorch or
|
|
|
|
| 368 |
else:
|
| 369 |
self.register_parameter("bias", None)
|
| 370 |
|
| 371 |
+
# Caches for inference path.
|
|
|
|
|
|
|
| 372 |
self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
|
| 373 |
self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
|
| 375 |
self._packed_version = -1
|
| 376 |
self._dense_version = -1
|
|
|
|
| 391 |
def invalidate_packed(self) -> None:
|
| 392 |
"""Mark the packed cache stale. Called after weight mutations."""
|
| 393 |
self._cache_version += 1
|
|
|
|
| 394 |
if self._dense_w.numel() > 0:
|
| 395 |
self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
|
| 396 |
self._dense_version = -1
|
|
|
|
| 415 |
packed = ext.pack_ternary(w_q)
|
| 416 |
else:
|
| 417 |
packed = pack_ternary(w_q)
|
|
|
|
| 418 |
self._packed = packed.contiguous()
|
| 419 |
self._alpha = alpha.contiguous()
|
| 420 |
self._packed_version = self._cache_version
|
|
|
|
| 429 |
def ternary_nonzero_mask(self) -> torch.Tensor:
|
| 430 |
"""Boolean mask of currently non-zero ternary positions (cached)."""
|
| 431 |
self._ensure_packed()
|
|
|
|
|
|
|
| 432 |
ext = _NATIVE_EXT
|
| 433 |
if ext is not None:
|
| 434 |
w = ext.unpack_ternary(self._packed, self.in_features)
|
|
|
|
| 439 |
# -- forward ---------------------------------------------------------------
|
| 440 |
|
| 441 |
def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
|
| 442 |
+
"""STE forward: differentiable, fp32/bf16 dense matmul.
|
| 443 |
+
|
| 444 |
+
Uses detach() trick for torch.compile compatibility:
|
| 445 |
+
w_scaled = w / alpha
|
| 446 |
+
w_q = w_scaled + (round(clamp(w_scaled)) - w_scaled).detach()
|
| 447 |
+
output = F.linear(x, w_q * alpha)
|
| 448 |
+
|
| 449 |
+
Forward: w_q evaluates to round(clamp(w/alpha, -1, 1))
|
| 450 |
+
Backward: grad flows through w_scaled unchanged (STE identity)
|
| 451 |
+
"""
|
| 452 |
w = self.weight
|
| 453 |
alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
|
| 454 |
+
w_scaled = w / alpha
|
| 455 |
+
# STE via detach trick — zero graph breaks under torch.compile
|
| 456 |
+
w_q = w_scaled + (torch.round(torch.clamp(w_scaled, -1.0, 1.0)) - w_scaled).detach()
|
| 457 |
+
w_q = w_q * alpha
|
| 458 |
if self.use_2_4:
|
|
|
|
|
|
|
| 459 |
with torch.no_grad():
|
| 460 |
mask = (apply_2_4_sparsity_(w_q.detach().clone()) != 0).to(w_q.dtype)
|
| 461 |
w_q = w_q * mask
|
|
|
|
| 471 |
w = ext.dequantize(self._packed, self._alpha, self.in_features)
|
| 472 |
else:
|
| 473 |
w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1)
|
|
|
|
| 474 |
self._dense_w = w.contiguous()
|
| 475 |
self._dense_version = self._cache_version
|
| 476 |
return self._dense_w
|
|
|
|
| 478 |
def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
|
| 479 |
"""No-grad fast path that uses the cached dequantised weights."""
|
| 480 |
w = self._ensure_dense()
|
|
|
|
| 481 |
if x.dtype != w.dtype:
|
| 482 |
w_used = w.to(x.dtype)
|
| 483 |
else:
|
|
|
|
| 513 |
self.weight = nn.Parameter(torch.ones(self.dim))
|
| 514 |
|
| 515 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
|
|
| 516 |
dtype = x.dtype
|
| 517 |
if dtype != torch.float32:
|
| 518 |
x32 = x.float()
|