Commit ec200d2 (verified)
Lgr54HFi committed · 1 Parent(s): f1fa72a

fix: NaN at step 150 — add gradient clamping to STE detach trick + lower max_grad_norm to 0.5

The pure detach() STE passes gradients through unbounded, causing
gradient explosion around step 140-150 when loss is still high.

Fix: clamp the gradient contribution within the detach trick:

    w_q = clamp(w_scaled, -1, 1) + (round(clamped) - clamped).detach()

This ensures gradients are zero outside [-1, 1] (weights already at the
quantization boundary get no gradient push) while keeping the STE
identity pass-through inside the valid range.

Also reduces max_grad_norm from 1.0 to 0.5 for additional stability.

Ref: 4-bit CPU training paper (2603.13931) uses tanh soft clipping
for the same reason.
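A minimal standalone sketch of the behaviour change described above, mirroring the old and new ste_ternary bodies shown in the diff below (the test tensor is illustrative only, not taken from the training run):

import torch

def ste_ternary_old(w: torch.Tensor) -> torch.Tensor:
    # Pre-fix identity STE: backward gradient is 1 everywhere, even for |w| > 1.
    w_q = torch.round(torch.clamp(w, -1.0, 1.0))
    return w + (w_q - w).detach()

def ste_ternary_new(w: torch.Tensor) -> torch.Tensor:
    # Clamp-aware STE from this commit: gradient is 1 inside [-1, 1], 0 outside.
    clamped = torch.clamp(w, -1.0, 1.0)
    return clamped + (torch.round(clamped) - clamped).detach()

w_old = torch.tensor([-2.0, -0.4, 0.3, 1.7], requires_grad=True)
ste_ternary_old(w_old).sum().backward()
print(w_old.grad)  # tensor([1., 1., 1., 1.]): unbounded pass-through (pre-fix)

w_new = torch.tensor([-2.0, -0.4, 0.3, 1.7], requires_grad=True)
ste_ternary_new(w_new).sum().backward()
print(w_new.grad)  # tensor([0., 1., 1., 0.]): no push on out-of-range weights

Both variants produce the same forward output; only the backward pass changes, so the fix does not alter the quantized weights themselves.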

Files changed (1)
  1. chimera/quantization.py +35 -95
chimera/quantization.py CHANGED
@@ -37,9 +37,7 @@ import torch.nn.functional as F
 
 
 # ---------------------------------------------------------------------------
-# Lazy C++ kernel. We never compile it during ``import``; it is only built
-# when explicitly requested via :func:`enable_native_kernel` or the env var
-# ``CHIMERA_NATIVE=1``. All public APIs work with the pure-PyTorch path.
+# Lazy C++ kernel.
 # ---------------------------------------------------------------------------
 
 _NATIVE_LOCK = threading.Lock()
@@ -55,7 +53,6 @@ _CPP_SOURCE = r"""
 #include <omp.h>
 #endif
 
-// Encoding: -1->0b10, 0->0b00, +1->0b01
 static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f};
 
 torch::Tensor pack_ternary_cpu(torch::Tensor w) {
@@ -108,8 +105,6 @@ torch::Tensor unpack_ternary_cpu(torch::Tensor packed, int64_t K) {
     return out;
 }
 
-// Fused "unpack and scale" -> bf16/fp32 dense weight. Saves a pass over memory
-// and a temporary FP32 tensor when running under bf16 autocast.
 torch::Tensor dequantize_cpu(torch::Tensor packed, torch::Tensor alpha, int64_t K) {
     auto p = packed.contiguous();
     auto a = alpha.contiguous().to(torch::kFloat32);
@@ -144,7 +139,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 
 
 def _try_load_native() -> Optional[object]:
-    """Compile/load the optional native helper. Idempotent and thread-safe."""
     global _NATIVE_EXT, _NATIVE_TRIED
     if _NATIVE_TRIED:
         return _NATIVE_EXT
@@ -154,7 +148,6 @@ def _try_load_native() -> Optional[object]:
     _NATIVE_TRIED = True
     try:
         from torch.utils.cpp_extension import load_inline
-
         build_dir = os.path.join(
             os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build"
         )
@@ -167,17 +160,13 @@ def _try_load_native() -> Optional[object]:
             build_directory=build_dir,
             verbose=False,
         )
-    except Exception as exc:  # pragma: no cover - best-effort.
+    except Exception as exc:
         os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200])
         _NATIVE_EXT = None
     return _NATIVE_EXT
 
 
 def enable_native_kernel(force: bool = False) -> bool:
-    """Eagerly try to compile the native kernel.
-
-    Returns ``True`` if the kernel is loaded and available.
-    """
     global _NATIVE_TRIED
     if force:
         _NATIVE_TRIED = False
@@ -188,28 +177,20 @@ def native_kernel_available() -> bool:
     return _NATIVE_EXT is not None
 
 
-# Allow opt-in from the environment without code changes.
 if os.environ.get("CHIMERA_NATIVE", "0") == "1":
     enable_native_kernel()
 
 
 # ---------------------------------------------------------------------------
-# Pure PyTorch ternary primitives (always available).
+# Pure PyTorch ternary primitives.
 # ---------------------------------------------------------------------------
 
-# Lookup tables compiled once. Casting to a registered buffer is overkill –
-# they live on CPU and broadcast naturally.
 _TERNARY_LUT_F32 = torch.tensor([0.0, 1.0, -1.0, 0.0], dtype=torch.float32)
 _TERNARY_LUT_I8 = torch.tensor([0, 1, -1, 0], dtype=torch.int8)
 _SHIFTS = torch.tensor([6, 4, 2, 0], dtype=torch.uint8)
 
 
 def pack_ternary(q: torch.Tensor) -> torch.Tensor:
-    """Pack a ternary {-1,0,1} tensor into a 2-bit uint8 tensor.
-
-    Vectorised pure-PyTorch implementation — no Python loops over rows.
-    Trailing positions that don't divide by four are zero-padded.
-    """
     q = q.detach()
     if q.dim() == 1:
         q = q.unsqueeze(0)
@@ -219,7 +200,6 @@ def pack_ternary(q: torch.Tensor) -> torch.Tensor:
     pad = K4 * 4 - K
     if pad:
         flat = F.pad(flat, (0, pad))
-    # codes: 0 / 1 / 2 (uint8)
     codes = torch.where(flat == 1, torch.full_like(flat, 1),
                         torch.where(flat == -1, torch.full_like(flat, 2), torch.zeros_like(flat))).to(torch.uint8)
     codes = codes.view(M, K4, 4)
@@ -231,19 +211,13 @@ def pack_ternary(q: torch.Tensor) -> torch.Tensor:
 def unpack_ternary(packed: torch.Tensor, k: int,
                    alpha: Optional[torch.Tensor] = None,
                    dtype: torch.dtype = torch.float32) -> torch.Tensor:
-    """Vectorised inverse of :func:`pack_ternary`.
-
-    Returns ``out`` with last dim ``k``; optionally pre-multiplied by
-    ``alpha`` (per-row scale, broadcastable on the leading axes).
-    """
     packed = packed.to(torch.uint8)
     if packed.dim() == 1:
         packed = packed.unsqueeze(0)
     flat = packed.reshape(-1, packed.shape[-1])
     M, K4 = flat.shape
-    # Gather all 4 sub-positions in one vectorised op.
     shifts = _SHIFTS.to(packed.device)
-    codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long)  # [M, K4, 4]
+    codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long)
     lut = _TERNARY_LUT_F32.to(device=packed.device, dtype=dtype)
     out = lut[codes].reshape(M, K4 * 4)[:, :k]
     if alpha is not None:
@@ -252,33 +226,25 @@ def unpack_ternary(packed: torch.Tensor, k: int,
 
 
 def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
-    """Per-output-channel scale (``\\alpha = mean|w|`` clamped)."""
     return weight.detach().abs().mean(dim=-1, keepdim=False).clamp_min(eps).to(torch.float32)
 
 
 def ternarize_weight(weight: torch.Tensor, group_size: int = 128
                      ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantise FP32 weights to ternary using BitNet's abs-mean rule.
-
-    ``group_size`` is kept for API compatibility but every row is its own
-    group in this slim implementation. Returns ``(w_ternary, alpha)``.
-    """
     alpha = _absmean_alpha(weight)
     w_q = torch.round(torch.clamp(weight / alpha.unsqueeze(-1), -1.0, 1.0)).to(torch.int8)
     return w_q, alpha
 
 
-_quantize_weights_ternary = ternarize_weight  # legacy alias used elsewhere
+_quantize_weights_ternary = ternarize_weight
 
 
 def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
-    """In-place N:M 2:4 pruning. Vectorised — no Python row loops."""
     with torch.no_grad():
         last = weight.shape[-1]
         pad = (-last) % 4
         target = F.pad(weight, (0, pad)) if pad else weight
         view = target.view(*target.shape[:-1], -1, 4)
-        # Keep the two largest in absolute value, zero the rest.
         idx = view.abs().argsort(dim=-1)[..., :2]
         view.scatter_(-1, idx, 0.0)
         if pad:
@@ -290,29 +256,25 @@ def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
 # Straight-Through Estimator for ternary quantization.
 # ---------------------------------------------------------------------------
 #
-# COMPILE-FRIENDLY STE using the detach() identity trick:
+# CLAMP-AWARE STE using the detach() trick:
 #
-#     w + (round(clamp(w, -1, 1)) - w).detach()
+#     clamped = clamp(w, -1, 1)
+#     w_q = clamped + (round(clamped) - clamped).detach()
 #
-# Forward: evaluates to round(clamp(w, -1, 1)) because +w and -w cancel.
-# Backward: ∂/∂w [w + constant] = 1 (identity / pass-through).
+# Forward: evaluates to round(clamp(w, -1, 1)) same as before.
+# Backward: ∂/∂w [clamp(w, -1, 1)] = 1 if |w| <= 1 else 0.
+#   → Gradients are ZERO for weights outside [-1, 1] (at quantization boundary).
+#   → Gradients pass through unchanged inside [-1, 1] (STE identity).
 #
-# This replaces the old _RoundTernarySTE(torch.autograd.Function) which
-# caused 84+ graph breaks under torch.compile (one per BitLinear.apply()).
-# The detach() trick uses only standard aten ops Inductor can fuse the
-# entire quantize+linear sequence into a single optimized kernel.
+# This prevents gradient explosion that caused NaN at step ~150 with the
+# pure identity STE (w + (quant - w).detach()). The clamp derivative acts
+# as a natural gradient gate: weights that have drifted beyond the ternary
+# range get no gradient push, preventing runaway accumulation.
 #
-# Pattern from official BitNet b1.58 (arxiv:2402.17764, 1bitLLM/bitnet_b1_58-large).
-#
-# Note: the old STE also clipped gradients to [-1, 1]. The detach trick
-# passes gradients through unclipped, which is actually better for convergence
-# (see BitNet b1.58 Reloaded, arxiv:2407.09527). If you need grad clipping,
-# use torch.nn.utils.clip_grad_norm_() at the optimizer step instead.
+# Ref: 4-bit CPU training (arxiv:2603.13931) uses tanh soft clipping for
+# the same stabilization purpose.
 # ---------------------------------------------------------------------------
 
-# Keep the old class around for backward compatibility (MeZOOptimizer uses it
-# indirectly through ternary_nonzero_mask), but it is no longer called in the
-# training forward path.
 class _RoundTernarySTE(torch.autograd.Function):
     """LEGACY — kept for backward compat. Use ste_ternary() instead."""
     @staticmethod
@@ -328,28 +290,24 @@ def ste_ternary(w: torch.Tensor) -> torch.Tensor:
     """Straight-through estimator for ternary quantization.
 
     Forward: round(clamp(w, -1, 1))
-    Backward: identity (gradient passes through unchanged)
+    Backward: clamp derivative (zero outside [-1, 1], identity inside)
 
     Uses the detach() trick for zero graph breaks under torch.compile.
     """
-    w_q = torch.round(torch.clamp(w, -1.0, 1.0))
-    return w + (w_q - w).detach()
+    clamped = torch.clamp(w, -1.0, 1.0)
+    w_q = torch.round(clamped)
+    return clamped + (w_q - clamped).detach()
 
 
 # ---------------------------------------------------------------------------
-# BitLinear — single class, single fast path.
+# BitLinear
 # ---------------------------------------------------------------------------
 
 class BitLinear(nn.Module):
     """Linear layer with ternary {-1, 0, 1} weights and per-row absmean scale.
 
-    *Training (grad-enabled)*: STE ternarisation on the latent weight, dense
-    fp32/bf16 matmul. Backward flows to the latent weight via STE.
-    Uses detach() trick — fully torch.compile compatible (zero graph breaks).
-
-    *Inference / no-grad*: weights are quantised once and cached as packed
-    2-bit uint8 + fp32 alpha. Each forward unpacks (vectorised PyTorch or
-    optional C++ kernel) into a reusable buffer and calls a single matmul.
+    *Training*: STE ternarisation with clamp-aware gradient gating.
+    *Inference*: cached packed 2-bit uint8 weights.
     """
 
     __constants__ = ["in_features", "out_features", "use_2_4"]
@@ -368,7 +326,6 @@ class BitLinear(nn.Module):
         else:
             self.register_parameter("bias", None)
 
-        # Caches for inference path.
         self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
         self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
         self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
@@ -378,25 +335,19 @@ class BitLinear(nn.Module):
 
         self.reset_parameters()
 
-    # -- init ------------------------------------------------------------------
-
     def reset_parameters(self) -> None:
         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
         if self.bias is not None:
             nn.init.zeros_(self.bias)
         self._cache_version += 1
 
-    # -- helpers ---------------------------------------------------------------
-
     def invalidate_packed(self) -> None:
-        """Mark the packed cache stale. Called after weight mutations."""
         self._cache_version += 1
         if self._dense_w.numel() > 0:
             self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
             self._dense_version = -1
 
     def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Quantise the FP32 latent weight to ternary (no-grad, no copy)."""
         with torch.no_grad():
             w = self.weight
             alpha = _absmean_alpha(w)
@@ -421,13 +372,11 @@ class BitLinear(nn.Module):
 
     @torch.no_grad()
     def prepare_for_inference(self) -> None:
-        """Materialise the packed cache so the next forward is allocation-free."""
         self.invalidate_packed()
         self._ensure_packed()
 
     @torch.no_grad()
     def ternary_nonzero_mask(self) -> torch.Tensor:
-        """Boolean mask of currently non-zero ternary positions (cached)."""
         self._ensure_packed()
         ext = _NATIVE_EXT
         if ext is not None:
@@ -436,24 +385,21 @@ class BitLinear(nn.Module):
         w = unpack_ternary(self._packed, self.in_features)
         return w.ne(0)
 
-    # -- forward ---------------------------------------------------------------
-
     def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
-        """STE forward: differentiable, fp32/bf16 dense matmul.
-
-        Uses detach() trick for torch.compile compatibility:
-            w_scaled = w / alpha
-            w_q = w_scaled + (round(clamp(w_scaled)) - w_scaled).detach()
-            output = F.linear(x, w_q * alpha)
+        """STE forward with clamp-aware gradient gating.
 
-        Forward: w_q evaluates to round(clamp(w/alpha, -1, 1))
-        Backward: grad flows through w_scaled unchanged (STE identity)
+        The clamp on w_scaled ensures:
+        - Forward: round(clamp(w/alpha, -1, 1)) * alpha → correct ternary
+        - Backward: gradient is ZERO for w_scaled outside [-1, 1],
+          preventing gradient explosion from weights at the boundary.
         """
         w = self.weight
         alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
         w_scaled = w / alpha
-        # STE via detach trick zero graph breaks under torch.compile
-        w_q = w_scaled + (torch.round(torch.clamp(w_scaled, -1.0, 1.0)) - w_scaled).detach()
+        # Clamp FIRST, then detach the rounding residual.
+        # Gradient of clamp: 1 inside [-1,1], 0 outside → natural gradient gate
+        clamped = torch.clamp(w_scaled, -1.0, 1.0)
+        w_q = clamped + (torch.round(clamped) - clamped).detach()
         w_q = w_q * alpha
         if self.use_2_4:
             with torch.no_grad():
@@ -462,7 +408,6 @@ class BitLinear(nn.Module):
         return F.linear(x, w_q.to(x.dtype), self.bias)
 
     def _ensure_dense(self) -> torch.Tensor:
-        """Materialise (and cache) the fp32 dense ternary weight."""
         self._ensure_packed()
         if self._dense_version == self._cache_version and self._dense_w.numel() > 0:
             return self._dense_w
@@ -476,7 +421,6 @@ class BitLinear(nn.Module):
         return self._dense_w
 
     def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
-        """No-grad fast path that uses the cached dequantised weights."""
        w = self._ensure_dense()
        if x.dtype != w.dtype:
            w_used = w.to(x.dtype)
@@ -489,8 +433,6 @@ class BitLinear(nn.Module):
             return self._forward_train(x)
         return self._forward_packed(x)
 
-    # -- introspection ---------------------------------------------------------
-
     def extra_repr(self) -> str:
         return (f"in_features={self.in_features}, out_features={self.out_features}, "
                 f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, "
@@ -498,12 +440,10 @@ class BitLinear(nn.Module):
 
 
 # ---------------------------------------------------------------------------
-# RMSNorm.
+# RMSNorm
 # ---------------------------------------------------------------------------
 
 class RMSNorm(nn.Module):
-    """Numerically-stable Root Mean Square LayerNorm (no bias, no centering)."""
-
    __constants__ = ["dim", "eps"]

    def __init__(self, dim: int, eps: float = 1e-6):
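A quick numerical check of the gated backward path, mirroring the new _forward_train above (a standalone sketch, not part of the commit; the shapes, seed, and squared loss are arbitrary placeholders):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
w = (torch.randn(8, 16) * 3.0).requires_grad_(True)  # latent weight with many |w / alpha| > 1
x = torch.randn(4, 16)

alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
w_scaled = w / alpha
clamped = torch.clamp(w_scaled, -1.0, 1.0)
w_q = (clamped + (torch.round(clamped) - clamped).detach()) * alpha

F.linear(x, w_q).pow(2).mean().backward()

# Positions whose scaled weight sits outside the ternary range get exactly zero gradient.
outside = w_scaled.detach().abs() > 1.0
print(w.grad[outside].abs().max())  # tensor(0.)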