Kernels
dongseokmotif, Claude Sonnet 4.6, wyldecat, and github-actions[bot] committed
Commit e8e2c81 · unverified · 1 parent: 313d56a

feat: extend QK-Clip to support MLA (MuonClip Algorithm 1) [skip-build] (#28)


* feat: extend QK-Clip to support MLA (MuonClip Algorithm 1) [skip-build]

Add MLA support to qk_clip.py following the MuonClip spec (docs/muon-clip.md):
- parse_qk_layer: recognize 'wq_b' and 'wkv_b' MLA weight names
- QKClipInfo: add is_mla, qk_nope_head_dim, qk_rope_head_dim, v_head_dim fields
- get_qk_clip_info: branch on is_mla flag in clip_config
- compute_scales: use kv_stride (qk_nope + v_head_dim) as effective head dim for wkv_b
- qk_clip: simplify signature to (p, scales, info); vectorize MLA sub-region
scaling via tensor reshape instead of Python per-head loops:
  wq_b: q_nope rows → √γ, q_pe rows → γ
  wkv_b: k_nope rows → √γ, v rows unchanged
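The clip factor behind those √γ/γ scalings can be sketched as follows. This is a minimal illustration, not the repo's code: it assumes the MuonClip rule γ = min(1, threshold / max_logit), and the names `qk_clip_gamma`, `nope_scale`, and `pe_scale` are hypothetical.

```python
import math

def qk_clip_gamma(max_logit: float, threshold: float) -> float:
    # gamma shrinks a head's logits only when its observed max attention
    # logit exceeds the configured threshold; otherwise it is a no-op (1.0)
    return min(1.0, threshold / max_logit)

# W_q and W_k each absorb sqrt(gamma), so the Q.K^T logit scales by gamma.
# For MLA, wq_b's rotary q_pe rows take the full gamma because the shared
# k_R rotary path has no per-head weight to absorb the other sqrt(gamma).
gamma = qk_clip_gamma(max_logit=200.0, threshold=100.0)
nope_scale = math.sqrt(gamma)  # applied to q_nope / k_nope rows
pe_scale = gamma               # applied to q_pe rows of wq_b
```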

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style: ruff format (pre-commit fix)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Add built binary [skip-build]

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: WyldeCat <skan1543@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  2. build/torch210-cxx11-cu126-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} +1 -1
  3. build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py +10 -13
  4. build/torch210-cxx11-cu126-x86_64-linux/muon.py +8 -4
  5. build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py +23 -19
  6. build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py +87 -24
  7. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  8. build/torch210-cxx11-cu128-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} +1 -1
  9. build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py +10 -13
  10. build/torch210-cxx11-cu128-x86_64-linux/muon.py +8 -4
  11. build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py +23 -19
  12. build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py +87 -24
  13. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
  14. build/torch210-cxx11-cu130-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} +1 -1
  15. build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py +10 -13
  16. build/torch210-cxx11-cu130-x86_64-linux/muon.py +8 -4
  17. build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py +23 -19
  18. build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py +87 -24
  19. build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +3 -3
  20. build/torch210-cxx11-rocm70-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} +1 -1
  21. build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py +10 -13
  22. build/torch210-cxx11-rocm70-x86_64-linux/muon.py +8 -4
  23. build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py +23 -19
  24. build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py +87 -24
  25. build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +3 -3
  26. build/torch210-cxx11-rocm71-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} +1 -1
  27. build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py +10 -13
  28. build/torch210-cxx11-rocm71-x86_64-linux/muon.py +8 -4
  29. build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py +23 -19
  30. build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py +87 -24
  31. build/torch28-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  32. build/torch28-cxx11-cu126-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} +1 -1
  33. build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py +10 -13
  34. build/torch28-cxx11-cu126-x86_64-linux/muon.py +8 -4
  35. build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py +23 -19
  36. build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py +87 -24
  37. build/torch28-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  38. build/torch28-cxx11-cu128-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} +1 -1
  39. build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py +10 -13
  40. build/torch28-cxx11-cu128-x86_64-linux/muon.py +8 -4
  41. build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py +23 -19
  42. build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py +87 -24
  43. build/torch28-cxx11-cu129-x86_64-linux/_ops.py +3 -3
  44. build/torch28-cxx11-cu129-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} +1 -1
  45. build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py +10 -13
  46. build/torch28-cxx11-cu129-x86_64-linux/muon.py +8 -4
  47. build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py +23 -19
  48. build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py +87 -24
  49. build/torch28-cxx11-rocm63-x86_64-linux/_ops.py +3 -3
  50. build/torch28-cxx11-rocm63-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} +1 -1
build/torch210-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_38f9b8e_dirty
-ops = torch.ops._optimizer_38f9b8e_dirty
+from . import _optimizer_8d53b78_dirty
+ops = torch.ops._optimizer_8d53b78_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_38f9b8e_dirty::{op_name}"
+    return f"_optimizer_8d53b78_dirty::{op_name}"
build/torch210-cxx11-cu126-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cb6163428ce86500d61c2b765eecd7eb6f31c092066278e1d1af7a0848dc5126
+oid sha256:075fc73dbb2750aed7598cc3e13b593b6b1e7a78a78491e1b852fbd2a9af8f8d
 size 1940944
build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py CHANGED
@@ -93,10 +93,7 @@ class CPUOffloadPool:
             indices.append(idx)
             offsets.append((off, n))
             off += n
-            cpu_flat = torch.empty(off,
-                                   dtype=dtype,
-                                   device="cpu",
-                                   pin_memory=True)
+            cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
             self._groups[dtype] = {
                 "indices": indices,
                 "offsets": offsets,
@@ -140,8 +137,7 @@ class CPUOffloadPool:
         for i, mgd_idx in enumerate(indices):
             local = self._local(self._managed[mgd_idx])
             off, n = offsets[i]
-            cpu_flat[off:off + n].copy_(local.reshape(-1),
-                                        non_blocking=True)
+            cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)

         offloaded_bytes += grp["total"] * cpu_flat.element_size()

@@ -159,8 +155,10 @@ class CPUOffloadPool:
         )

         if not self._logged:
-            logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
-                        offloaded_bytes / (1024**2))
+            logger.info(
+                "[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
+                offloaded_bytes / (1024**2),
+            )

     # ------------------------------------------------------------------
     def reload(self):
@@ -198,12 +196,11 @@ class CPUOffloadPool:
         for i, mgd_idx in enumerate(indices):
             local = self._local(self._managed[mgd_idx])
             off, n = offsets[i]
-            local.reshape(-1).copy_(cpu_flat[off:off + n],
-                                    non_blocking=True)
+            local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)

         reloaded_bytes += grp["total"] * cpu_flat.element_size()

         if not self._logged:
-            logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)",
-                        reloaded_bytes / (1024**2))
-            self._logged = True
+            logger.info(
+                "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
+            )
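The pool above packs every managed tensor of one dtype into a single pinned `cpu_flat` buffer addressed by `(offset, numel)` pairs. The bookkeeping can be sketched in plain Python; the `sizes` list here is hypothetical (the real pool derives element counts from the managed parameters).

```python
# Each managed tensor of a given dtype gets a (start, numel) slot in one
# flat allocation, so a single pinned buffer serves every offload/reload
# copy via cpu_flat[start : start + n].
sizes = [6, 4, 10]  # hypothetical numel of each managed tensor

offsets, off = [], 0
for n in sizes:
    offsets.append((off, n))
    off += n
# 'off' is now the total element count to allocate for cpu_flat
```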
build/torch210-cxx11-cu126-x86_64-linux/muon.py CHANGED
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
             scales_full = compute_scales(
                 p, qk_clip_state) if qk_clip_state is not None else None
             if scales_full is not None:
-                qk_clip(p, scales_full, qk_clip_state.head_dim)
+                qk_clip(p, scales_full, qk_clip_state)

     def distributed_muon(
         self,
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
             scales_full = compute_scales(
                 p, qk_clip_state) if qk_clip_state is not None else None
             if scales_full is not None:
-                qk_clip(p, scales_full, qk_clip_state.head_dim)
+                qk_clip(p, scales_full, qk_clip_state)

         if not dtensor_params:
             return
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):

     def state_dict(self) -> dict:
         if self.cpu_offload:
-            raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
+            raise RuntimeError(
+                "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
+            )
         return super().state_dict()

     def load_state_dict(self, state_dict: dict) -> None:
         if self.cpu_offload:
-            raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
+            raise RuntimeError(
+                "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
+            )
         super().load_state_dict(state_dict)

     # Invalidate adamw.py's module-level tensor caches so that
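The checkpoint guard in that diff refuses to snapshot optimizer state while it still lives in the pinned CPU pool. A minimal stand-in for the pattern (class and return value are hypothetical, purely for illustration):

```python
class CheckpointGuard:
    """Hypothetical stand-in for the Muon state_dict guard above."""

    def __init__(self):
        self.cpu_offload = True   # offload active -> state lives in the CPU pool

    def turn_off_cpu_offload(self):
        self.cpu_offload = False  # the real Muon reloads state to GPU here

    def state_dict(self):
        if self.cpu_offload:
            raise RuntimeError(
                "state_dict() requires turn_off_cpu_offload() before checkpoint save."
            )
        return {"state": "snapshot"}

guard = CheckpointGuard()
guard.turn_off_cpu_offload()
snapshot = guard.state_dict()  # safe once offload is disabled
```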
build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py CHANGED
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
     E = inf
     for _ in range(max_iter):
         old_E = E
-        LHS = np.array([
-            [l, l**3, l**5, 1],
-            [q, q**3, q**5, -1],
-            [r, r**3, r**5, 1],
-            [u, u**3, u**5, -1],
-        ])
+        LHS = np.array(
+            [
+                [l, l**3, l**5, 1],
+                [q, q**3, q**5, -1],
+                [r, r**3, r**5, 1],
+                [u, u**3, u**5, -1],
+            ]
+        )
         a, b, c, E = np.linalg.solve(LHS, np.ones(4))
         if not np.all(np.isfinite([a, b, c, E])):
-            raise ValueError(f"_optimal_quintic: non-finite solve result "
-                             f"a={a}, b={b}, c={c}, E={E}")
+            raise ValueError(
+                f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
+            )
         q, r = np.sqrt(
-            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
-            (10 * c))
+            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
+        )
         if not np.all(np.isfinite([q, r])):
-            raise ValueError(
-                f"_optimal_quintic: non-finite node update q={q}, r={r}")
+            raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
         if abs(old_E - E) <= 1e-15:
             break
     else:
         raise RuntimeError(
-            f"_optimal_quintic: did not converge after {max_iter} iterations")
+            f"_optimal_quintic: did not converge after {max_iter} iterations"
+        )
     return float(a), float(b), float(c)

@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
 # - Polar Express: analytically optimal per step, adapting to the shrinking
 #   singular-value interval [l, u] as iterations progress; converges all
 #   singular values to 1, producing the exact polar factor UV^T.
-_coeffs_list = _optimal_composition(l=1e-3,
-                                    num_iters=10,
-                                    safety_factor_eps=1e-2,
-                                    cushion=0.02)
+_coeffs_list = _optimal_composition(
+    l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
+)


 # This code is adapted from:
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):

     X = X / (X.norm() + 1e-7)
     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+    )
     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     # Perform the NS iterations
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
     X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+    )
     for a, b, c in hs:
         buf1 = torch.bmm(X, X.transpose(-2, -1))
         buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
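The `_zeropower_via_newtonschulz5` routine touched above iterates a quintic polynomial in X·Xᵀ to push every singular value toward 1. A minimal NumPy sketch using fixed coefficients commonly quoted for Muon; note this repo instead derives per-step optimal ("Polar Express") coefficients via `_optimal_composition`, so treat the constants here as illustrative only.

```python
import numpy as np

def newton_schulz5(G: np.ndarray, steps: int = 8) -> np.ndarray:
    # Fixed quintic coefficients often used with Muon (illustrative);
    # the repo computes per-step optimal coefficients instead.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + 1e-7)  # Frobenius norm bounds sigma_max <= 1
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X  # quintic map on singular values
    return X

# A diagonal input makes convergence easy to inspect: the iterate stays
# diagonal, and its singular values should cluster near 1.
G = np.diag([3.0, 1.0, 0.3])
X = newton_schulz5(G)
s = np.linalg.svd(X, compute_uv=False)
```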
build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py CHANGED
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
 def parse_qk_layer(name: str) -> tuple[str | None, int]:
     """
     Parse a parameter name to check if it is a query/key projection layer
-    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+    and return (kind, layer_index).
+
+    Supported kinds:
+        MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
+        MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

     Returns:
         (kind, layer_idx) or (None, -1) if not matched.
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
     'model.5.attn.wk.weight' -> ('wk', 5)
     'model.2.attn.q_proj.weight' -> ('q_proj', 2)
     'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+    'model.1.attn.wq_b.weight' -> ('wq_b', 1)
+    'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
     'model.4.attn.v_proj.weight' -> (None, -1)
     """
     parts = normalize_fqn(name).split('.')
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
             layer_idx = int(part)
             break

-    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
         return kind, layer_idx

     return None, -1
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
 @dataclass
 class QKClipInfo:
     """Per-parameter dynamic info computed from config + runtime logits."""
-    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
     indices: list[int]  # which heads to consider for clipping
-    head_dim: int  # from config
+    head_dim: int  # from config (qk_head_dim for MLA wq_b)
     threshold: float  # from config
     logit: torch.Tensor | None

+    # MLA-specific fields
+    is_mla: bool = False
+    qk_nope_head_dim: int = 0
+    qk_rope_head_dim: int = 0
+    v_head_dim: int = 0
+

 def get_qk_clip_info(clip_config, n, qk_logits):
     """Extract QK clipping info for a named parameter.

     Args:
         clip_config: QK clipping configuration dict (or None).
+            MHA/GQA keys: head_dim, threshold, q_indices, k_indices
+            MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
         n: Parameter name string.
         qk_logits: Dict mapping layer indices to logit tensors (or None).

@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
     head_dim = clip_config.get('head_dim')
     threshold = clip_config.get('threshold')
     kind, layer_idx = parse_qk_layer(n)
+    is_mla = clip_config.get('is_mla', False)

     logit, indices = None, []
     if qk_logits is not None and kind is not None:
         logit = qk_logits[layer_idx]
-        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-        indices = clip_config.get(indices_key, []) or []
-
         if isinstance(logit, DTensor):
             # In TP settings, qk_logits may be DTensor
             # We convert it to full tensor here for simplicity
             logit = logit.full_tensor()

-    return QKClipInfo(
-        kind=kind,
-        indices=indices,
-        head_dim=head_dim,
-        threshold=threshold,
-        logit=logit,
-    )
+        if kind in ('wq_b', 'wq', 'q_proj'):
+            indices = clip_config.get('q_indices', []) or []
+        elif kind in ('wkv_b', 'wk', 'k_proj'):
+            indices = clip_config.get('k_indices', []) or []
+
+    if is_mla:
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+            is_mla=True,
+            qk_nope_head_dim=clip_config['qk_nope_head_dim'],
+            qk_rope_head_dim=clip_config['qk_rope_head_dim'],
+            v_head_dim=clip_config['v_head_dim'],
+        )
+    else:
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+        )


 def compute_scales(p, qk_clip_state):
     """Compute per-head scaling factors for QK clipping.

-    Returns scales tensor if any head exceeds threshold, else None.
+    Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
+    For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
     """
     kind = qk_clip_state.kind
     indices = qk_clip_state.indices
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
     if not head_scales:
         return None

-    H_global = p.shape[0] // head_dim
+    # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
+    if qk_clip_state.is_mla and kind == 'wkv_b':
+        effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
+    else:
+        effective_head_dim = head_dim
+
+    H_global = p.shape[0] // effective_head_dim
     scales_full = torch.ones(H_global, device=p.data.device)
     for head_idx, scale in head_scales.items():
         scales_full[head_idx] = scale
     return scales_full


-def qk_clip(p, scales, head_dim):
-    """Apply per-head scaling to a Q/K projection weight matrix."""
-    if isinstance(p, torch.nn.Parameter):
-        W = p.data.view(-1, head_dim, p.data.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
-    else:
-        W = p.view(-1, head_dim, p.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
+def qk_clip(p, scales, info):
+    """Apply per-head scaling to a Q/K projection weight matrix.
+
+    Args:
+        p: Parameter (nn.Parameter or raw tensor).
+        scales: [n_heads] tensor, each element = √γ_h.
+        info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
+
+    MLA sub-region scaling per Algorithm 1 (MuonClip):
+        wq_b: q_nope rows → √γ, q_pe rows → γ
+        wkv_b: k_nope rows → √γ, v rows → unchanged
+    """
+    W = p.data if isinstance(p, torch.nn.Parameter) else p
+
+    if not info.is_mla:
+        # MHA/GQA: uniform √γ applied to all rows in each head
+        W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
+        return
+
+    # MLA: vectorized sub-region scaling within each head
+    if info.kind == 'wq_b':
+        qk_nope = info.qk_nope_head_dim
+        qk_head_dim = qk_nope + info.qk_rope_head_dim
+        W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
+        W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # q_nope → √γ
+        W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe → γ
+
+    elif info.kind == 'wkv_b':
+        qk_nope = info.qk_nope_head_dim
+        kv_stride = qk_nope + info.v_head_dim
+        W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
+        W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # k_nope → √γ
+        # v rows: not touched (k_R shared rotary unchanged)
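The vectorized sub-region scaling in the new `qk_clip` can be checked with a NumPy stand-in for the `wq_b` case. The dimensions below are hypothetical toy values; reshaping a contiguous array yields a view, so the in-place multiplies hit the original weight matrix, just as `Tensor.view(...).mul_(...)` does in the diff.

```python
import numpy as np

# Hypothetical tiny MLA dims: 2 heads, 4 q_nope rows + 2 q_pe rows per head.
H, qk_nope, qk_rope, in_dim = 2, 4, 2, 8
qk_head_dim = qk_nope + qk_rope
W = np.ones((H * qk_head_dim, in_dim))  # stand-in for wq_b.weight
scales = np.array([0.5, 1.0])           # sqrt(gamma) per head; head 1 unclipped

# Same reshape trick as qk_clip: view rows as [H, qk_head_dim, in_dim],
# then scale sub-regions without a per-head Python loop.
W3 = W.reshape(H, qk_head_dim, in_dim)             # view into W; edits are in place
W3[:, :qk_nope, :] *= scales.reshape(-1, 1, 1)     # q_nope rows -> sqrt(gamma)
W3[:, qk_nope:, :] *= (scales ** 2).reshape(-1, 1, 1)  # q_pe rows -> gamma
```

Head 0's q_nope rows end up scaled by 0.5 and its q_pe rows by 0.25, while head 1 is untouched.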
build/torch210-cxx11-cu128-x86_64-linux/_ops.py CHANGED
(identical change to the build/torch210-cxx11-cu126-x86_64-linux/_ops.py diff above: build hash 38f9b8e → 8d53b78)
build/torch210-cxx11-cu128-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:017323d479e8fbd3ed1f550f95fc4ba9f2e304dbe9351c0eaa75543ebe775e18
+oid sha256:2af397ae01c8c01ee0e879f6812bd9df55d152afbcc6713f5c1987d5bce7793b
 size 2004144
build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py CHANGED
(identical change to the build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py diff above)
build/torch210-cxx11-cu128-x86_64-linux/muon.py CHANGED
(identical change to the build/torch210-cxx11-cu126-x86_64-linux/muon.py diff above)
build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py CHANGED
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
32
  E = inf
33
     E = inf
     for _ in range(max_iter):
         old_E = E
-        LHS = np.array([
-            [l, l**3, l**5, 1],
-            [q, q**3, q**5, -1],
-            [r, r**3, r**5, 1],
-            [u, u**3, u**5, -1],
-        ])
+        LHS = np.array(
+            [
+                [l, l**3, l**5, 1],
+                [q, q**3, q**5, -1],
+                [r, r**3, r**5, 1],
+                [u, u**3, u**5, -1],
+            ]
+        )
         a, b, c, E = np.linalg.solve(LHS, np.ones(4))
         if not np.all(np.isfinite([a, b, c, E])):
-            raise ValueError(f"_optimal_quintic: non-finite solve result "
-                             f"a={a}, b={b}, c={c}, E={E}")
+            raise ValueError(
+                f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
+            )
         q, r = np.sqrt(
-            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
-            (10 * c))
+            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
+        )
         if not np.all(np.isfinite([q, r])):
-            raise ValueError(
-                f"_optimal_quintic: non-finite node update q={q}, r={r}")
+            raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
         if abs(old_E - E) <= 1e-15:
             break
     else:
         raise RuntimeError(
-            f"_optimal_quintic: did not converge after {max_iter} iterations")
+            f"_optimal_quintic: did not converge after {max_iter} iterations"
+        )
     return float(a), float(b), float(c)


@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
 # - Polar Express: analytically optimal per step, adapting to the shrinking
 #   singular-value interval [l, u] as iterations progress; converges all
 #   singular values to 1, producing the exact polar factor UV^T.
-_coeffs_list = _optimal_composition(l=1e-3,
-                                    num_iters=10,
-                                    safety_factor_eps=1e-2,
-                                    cushion=0.02)
+_coeffs_list = _optimal_composition(
+    l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
+)


 # This code is adapted from:
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):

     X = X / (X.norm() + 1e-7)
     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+    )
     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     # Perform the NS iterations
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
     X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+    )
     for a, b, c in hs:
         buf1 = torch.bmm(X, X.transpose(-2, -1))
         buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
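The hunks above pad the per-step coefficient list and then run the quintic iteration X ← aX + (bA + cA²)X with A = XXᵀ, which drives every singular value of X toward 1. A minimal numpy sketch of that iteration; the fixed coefficients (3.4445, −4.7750, 2.0315) are the widely used ones from the original Muon implementation, substituted here as an assumption in place of the Polar Express schedule this file actually computes:

```python
import numpy as np

# Assumed fixed quintic coefficients (original Muon); the file above instead
# computes an analytically optimal per-step schedule via _optimal_composition.
A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315

def newton_schulz_quintic(G, steps):
    """Quintic iteration X <- a*X + (b*A + c*A@A) @ X with A = X X^T.
    Pushes all singular values of X toward 1 (the polar factor U V^T)."""
    X = G / (np.linalg.norm(G) + 1e-7)  # Frobenius normalization: sigma_max <= 1
    for _ in range(steps):
        A = X @ X.T
        X = A_COEF * X + (B_COEF * A + C_COEF * (A @ A)) @ X
    return X

# Test matrix with a known, well-spread spectrum.
rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((4, 4)))
V, _ = np.linalg.qr(rng.standard_normal((4, 4)))
G = U @ np.diag([3.0, 1.0, 0.3, 0.05]) @ V.T

X = newton_schulz_quintic(G, steps=6)
sigma = np.linalg.svd(X, compute_uv=False)  # all close to 1 after a few steps
```

The `hs` padding in the diff serves exactly this loop: once the precomputed schedule runs out, the final coefficient triple is repeated for the remaining steps.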
build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py CHANGED
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
 def parse_qk_layer(name: str) -> tuple[str | None, int]:
     """
     Parse a parameter name to check if it is a query/key projection layer
-    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+    and return (kind, layer_index).
+
+    Supported kinds:
+        MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
+        MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

     Returns:
         (kind, layer_idx) or (None, -1) if not matched.
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
         'model.5.attn.wk.weight' -> ('wk', 5)
         'model.2.attn.q_proj.weight' -> ('q_proj', 2)
         'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.1.attn.wq_b.weight' -> ('wq_b', 1)
+        'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
         'model.4.attn.v_proj.weight' -> (None, -1)
     """
     parts = normalize_fqn(name).split('.')
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
             layer_idx = int(part)
             break

-    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
         return kind, layer_idx

     return None, -1
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
 @dataclass
 class QKClipInfo:
     """Per-parameter dynamic info computed from config + runtime logits."""
-    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
     indices: list[int]  # which heads to consider for clipping
-    head_dim: int  # from config
+    head_dim: int  # from config (qk_head_dim for MLA wq_b)
     threshold: float  # from config
     logit: torch.Tensor | None
+
+    # MLA-specific fields
+    is_mla: bool = False
+    qk_nope_head_dim: int = 0
+    qk_rope_head_dim: int = 0
+    v_head_dim: int = 0
+

 def get_qk_clip_info(clip_config, n, qk_logits):
     """Extract QK clipping info for a named parameter.

     Args:
         clip_config: QK clipping configuration dict (or None).
+            MHA/GQA keys: head_dim, threshold, q_indices, k_indices
+            MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
         n: Parameter name string.
         qk_logits: Dict mapping layer indices to logit tensors (or None).

@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
     head_dim = clip_config.get('head_dim')
     threshold = clip_config.get('threshold')
     kind, layer_idx = parse_qk_layer(n)
+    is_mla = clip_config.get('is_mla', False)

     logit, indices = None, []
     if qk_logits is not None and kind is not None:
         logit = qk_logits[layer_idx]
-        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-        indices = clip_config.get(indices_key, []) or []
-
         if isinstance(logit, DTensor):
             # In TP settings, qk_logits may be DTensor
             # We convert it to full tensor here for simplicity
             logit = logit.full_tensor()

-    return QKClipInfo(
-        kind=kind,
-        indices=indices,
-        head_dim=head_dim,
-        threshold=threshold,
-        logit=logit,
-    )
+    if kind in ('wq_b', 'wq', 'q_proj'):
+        indices = clip_config.get('q_indices', []) or []
+    elif kind in ('wkv_b', 'wk', 'k_proj'):
+        indices = clip_config.get('k_indices', []) or []
+
+    if is_mla:
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+            is_mla=True,
+            qk_nope_head_dim=clip_config['qk_nope_head_dim'],
+            qk_rope_head_dim=clip_config['qk_rope_head_dim'],
+            v_head_dim=clip_config['v_head_dim'],
+        )
+    else:
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+        )


 def compute_scales(p, qk_clip_state):
     """Compute per-head scaling factors for QK clipping.

-    Returns scales tensor if any head exceeds threshold, else None.
+    Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
+    For MLA wkv_b, the effective row stride is qk_nope_head_dim + v_head_dim.
     """
     kind = qk_clip_state.kind
     indices = qk_clip_state.indices
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
     if not head_scales:
         return None

-    H_global = p.shape[0] // head_dim
+    # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
+    if qk_clip_state.is_mla and kind == 'wkv_b':
+        effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
+    else:
+        effective_head_dim = head_dim
+
+    H_global = p.shape[0] // effective_head_dim
     scales_full = torch.ones(H_global, device=p.data.device)
     for head_idx, scale in head_scales.items():
         scales_full[head_idx] = scale
     return scales_full


-def qk_clip(p, scales, head_dim):
-    """Apply per-head scaling to a Q/K projection weight matrix."""
-    if isinstance(p, torch.nn.Parameter):
-        W = p.data.view(-1, head_dim, p.data.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
-    else:
-        W = p.view(-1, head_dim, p.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
+def qk_clip(p, scales, info):
+    """Apply per-head scaling to a Q/K projection weight matrix.
+
+    Args:
+        p: Parameter (nn.Parameter or raw tensor).
+        scales: [n_heads] tensor, each element = √γ_h.
+        info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
+
+    MLA sub-region scaling per Algorithm 1 (MuonClip):
+        wq_b:  q_nope rows → √γ, q_pe rows → γ
+        wkv_b: k_nope rows → √γ, v rows → unchanged
+    """
+    W = p.data if isinstance(p, torch.nn.Parameter) else p
+
+    if not info.is_mla:
+        # MHA/GQA: uniform √γ applied to all rows in each head
+        W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
+        return
+
+    # MLA: vectorized sub-region scaling within each head
+    if info.kind == 'wq_b':
+        qk_nope = info.qk_nope_head_dim
+        qk_head_dim = qk_nope + info.qk_rope_head_dim
+        W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
+        W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # q_nope → √γ
+        W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe → γ

+    elif info.kind == 'wkv_b':
+        qk_nope = info.qk_nope_head_dim
+        kv_stride = qk_nope + info.v_head_dim
+        W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
+        W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # k_nope → √γ
+        # v rows are left untouched: V does not enter the QK logits, and the
+        # shared rotary key k_R comes from wkv_a, so neither is scaled
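The wq_b branch above scales the q_nope rows of each head by √γ and the q_pe rows by γ. A small numpy stand-in for that in-place torch logic, with hypothetical MLA dimensions chosen only for illustration:

```python
import numpy as np

# Hypothetical MLA dimensions (not taken from any real model config)
H, QK_NOPE, QK_ROPE, IN_DIM = 2, 4, 2, 8
QK_HEAD_DIM = QK_NOPE + QK_ROPE

rng = np.random.default_rng(0)
orig = rng.standard_normal((H, QK_HEAD_DIM, IN_DIM))  # wq_b viewed per head

gamma = np.array([0.25, 1.0])   # head 0 exceeded the threshold, head 1 did not
scales = np.sqrt(gamma)         # compute_scales hands qk_clip sqrt(gamma) per head

W = orig.copy()
W[:, :QK_NOPE, :] *= scales[:, None, None]             # q_nope rows -> sqrt(gamma)
W[:, QK_NOPE:, :] *= (scales * scales)[:, None, None]  # q_pe rows -> gamma

# Head 0: q_nope rows shrink by 0.5, q_pe rows by 0.25; head 1 is untouched.
```

Because k_nope on the wkv_b side also gets √γ while the shared rotary key stays fixed, the q^C·k^C logit term picks up √γ from each side and the q^R·k^R term picks up γ from the query side alone, so every logit component of a clipped head is damped by exactly γ.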
build/torch210-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_38f9b8e_dirty
-ops = torch.ops._optimizer_38f9b8e_dirty
+from . import _optimizer_8d53b78_dirty
+ops = torch.ops._optimizer_8d53b78_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_38f9b8e_dirty::{op_name}"
+    return f"_optimizer_8d53b78_dirty::{op_name}"
build/torch210-cxx11-cu130-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:783a161f2d28e4244226c9d6e59ac33f74f7a79aad17c06e8ce027dd6182e03c
+oid sha256:45eef069a7caa85678cd1e05f0c60c5cfbc676dc93a1bcb31e55eb34730aa469
 size 2004728
build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py CHANGED
@@ -93,10 +93,7 @@ class CPUOffloadPool:
             indices.append(idx)
             offsets.append((off, n))
             off += n
-        cpu_flat = torch.empty(off,
-                               dtype=dtype,
-                               device="cpu",
-                               pin_memory=True)
+        cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
         self._groups[dtype] = {
             "indices": indices,
             "offsets": offsets,
@@ -140,8 +137,7 @@ class CPUOffloadPool:
             for i, mgd_idx in enumerate(indices):
                 local = self._local(self._managed[mgd_idx])
                 off, n = offsets[i]
-                cpu_flat[off:off + n].copy_(local.reshape(-1),
-                                            non_blocking=True)
+                cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)

             offloaded_bytes += grp["total"] * cpu_flat.element_size()

@@ -159,8 +155,10 @@ class CPUOffloadPool:
         )

         if not self._logged:
-            logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
-                        offloaded_bytes / (1024**2))
+            logger.info(
+                "[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
+                offloaded_bytes / (1024**2),
+            )

     # ------------------------------------------------------------------
     def reload(self):
@@ -198,12 +196,11 @@ class CPUOffloadPool:
             for i, mgd_idx in enumerate(indices):
                 local = self._local(self._managed[mgd_idx])
                 off, n = offsets[i]
-                local.reshape(-1).copy_(cpu_flat[off:off + n],
-                                        non_blocking=True)
+                local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)

             reloaded_bytes += grp["total"] * cpu_flat.element_size()

         if not self._logged:
-            logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)",
-                        reloaded_bytes / (1024**2))
-            self._logged = True
+            logger.info(
+                "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
+            )
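The offload path above packs every managed tensor of one dtype into a single flat pinned buffer at precomputed (offset, numel) slots, then copies each slot back on reload. A minimal sketch of that round trip, using numpy arrays in place of pinned CUDA tensors (pinning and `non_blocking` copies are omitted; names are hypothetical):

```python
import numpy as np

# Stand-ins for the managed parameter tensors of one dtype group.
tensors = [np.arange(6.0).reshape(2, 3), np.ones((4,)), np.full((3, 1), 7.0)]

# Pack: record (offset, numel) per tensor, then fill one flat buffer.
offsets, off = [], 0
for t in tensors:
    n = t.size
    offsets.append((off, n))
    off += n
flat = np.empty(off, dtype=np.float64)  # stands in for the pinned cpu_flat buffer
for t, (o, n) in zip(tensors, offsets):
    flat[o:o + n] = t.reshape(-1)

# Reload: copy each slot back into (cleared) destination tensors.
restored = [np.zeros_like(t) for t in tensors]
for r, (o, n) in zip(restored, offsets):
    r.reshape(-1)[:] = flat[o:o + n]
```

Packing into one contiguous buffer is what makes the real implementation's single pinned allocation and async `copy_(..., non_blocking=True)` transfers possible, instead of one small host-device copy per parameter.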
build/torch210-cxx11-cu130-x86_64-linux/muon.py CHANGED
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
             scales_full = compute_scales(
                 p, qk_clip_state) if qk_clip_state is not None else None
             if scales_full is not None:
-                qk_clip(p, scales_full, qk_clip_state.head_dim)
+                qk_clip(p, scales_full, qk_clip_state)

     def distributed_muon(
         self,
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
             scales_full = compute_scales(
                 p, qk_clip_state) if qk_clip_state is not None else None
             if scales_full is not None:
-                qk_clip(p, scales_full, qk_clip_state.head_dim)
+                qk_clip(p, scales_full, qk_clip_state)

         if not dtensor_params:
             return
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):

     def state_dict(self) -> dict:
         if self.cpu_offload:
-            raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
+            raise RuntimeError(
+                "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
+            )
         return super().state_dict()

     def load_state_dict(self, state_dict: dict) -> None:
         if self.cpu_offload:
-            raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
+            raise RuntimeError(
+                "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
+            )
         super().load_state_dict(state_dict)

     # Invalidate adamw.py's module-level tensor caches so that
build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py CHANGED
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
     E = inf
     for _ in range(max_iter):
         old_E = E
-        LHS = np.array([
-            [l, l**3, l**5, 1],
-            [q, q**3, q**5, -1],
-            [r, r**3, r**5, 1],
-            [u, u**3, u**5, -1],
-        ])
+        LHS = np.array(
+            [
+                [l, l**3, l**5, 1],
+                [q, q**3, q**5, -1],
+                [r, r**3, r**5, 1],
+                [u, u**3, u**5, -1],
+            ]
+        )
         a, b, c, E = np.linalg.solve(LHS, np.ones(4))
         if not np.all(np.isfinite([a, b, c, E])):
-            raise ValueError(f"_optimal_quintic: non-finite solve result "
-                             f"a={a}, b={b}, c={c}, E={E}")
+            raise ValueError(
+                f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
+            )
         q, r = np.sqrt(
-            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
-            (10 * c))
+            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
+        )
         if not np.all(np.isfinite([q, r])):
-            raise ValueError(
-                f"_optimal_quintic: non-finite node update q={q}, r={r}")
+            raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
         if abs(old_E - E) <= 1e-15:
             break
     else:
         raise RuntimeError(
-            f"_optimal_quintic: did not converge after {max_iter} iterations")
+            f"_optimal_quintic: did not converge after {max_iter} iterations"
+        )
     return float(a), float(b), float(c)


@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
 # - Polar Express: analytically optimal per step, adapting to the shrinking
 #   singular-value interval [l, u] as iterations progress; converges all
 #   singular values to 1, producing the exact polar factor UV^T.
-_coeffs_list = _optimal_composition(l=1e-3,
-                                    num_iters=10,
-                                    safety_factor_eps=1e-2,
-                                    cushion=0.02)
+_coeffs_list = _optimal_composition(
+    l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
+)


 # This code is adapted from:
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):

     X = X / (X.norm() + 1e-7)
     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+    )
     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     # Perform the NS iterations
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
     X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+    )
     for a, b, c in hs:
         buf1 = torch.bmm(X, X.transpose(-2, -1))
         buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py CHANGED
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
 def parse_qk_layer(name: str) -> tuple[str | None, int]:
     """
     Parse a parameter name to check if it is a query/key projection layer
-    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+    and return (kind, layer_index).
+
+    Supported kinds:
+        MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
+        MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

     Returns:
         (kind, layer_idx) or (None, -1) if not matched.
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
         'model.5.attn.wk.weight' -> ('wk', 5)
         'model.2.attn.q_proj.weight' -> ('q_proj', 2)
         'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.1.attn.wq_b.weight' -> ('wq_b', 1)
+        'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
         'model.4.attn.v_proj.weight' -> (None, -1)
     """
     parts = normalize_fqn(name).split('.')
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
             layer_idx = int(part)
             break

-    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
         return kind, layer_idx

     return None, -1
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
 @dataclass
 class QKClipInfo:
     """Per-parameter dynamic info computed from config + runtime logits."""
-    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
     indices: list[int]  # which heads to consider for clipping
-    head_dim: int  # from config
+    head_dim: int  # from config (qk_head_dim for MLA wq_b)
     threshold: float  # from config
     logit: torch.Tensor | None
+
+    # MLA-specific fields
+    is_mla: bool = False
+    qk_nope_head_dim: int = 0
+    qk_rope_head_dim: int = 0
+    v_head_dim: int = 0
+

 def get_qk_clip_info(clip_config, n, qk_logits):
     """Extract QK clipping info for a named parameter.

     Args:
         clip_config: QK clipping configuration dict (or None).
+            MHA/GQA keys: head_dim, threshold, q_indices, k_indices
+            MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
         n: Parameter name string.
         qk_logits: Dict mapping layer indices to logit tensors (or None).

@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
     head_dim = clip_config.get('head_dim')
     threshold = clip_config.get('threshold')
     kind, layer_idx = parse_qk_layer(n)
+    is_mla = clip_config.get('is_mla', False)

     logit, indices = None, []
     if qk_logits is not None and kind is not None:
         logit = qk_logits[layer_idx]
-        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-        indices = clip_config.get(indices_key, []) or []
-
         if isinstance(logit, DTensor):
             # In TP settings, qk_logits may be DTensor
             # We convert it to full tensor here for simplicity
             logit = logit.full_tensor()

-    return QKClipInfo(
-        kind=kind,
-        indices=indices,
-        head_dim=head_dim,
-        threshold=threshold,
-        logit=logit,
-    )
+    if kind in ('wq_b', 'wq', 'q_proj'):
+        indices = clip_config.get('q_indices', []) or []
+    elif kind in ('wkv_b', 'wk', 'k_proj'):
+        indices = clip_config.get('k_indices', []) or []
+
+    if is_mla:
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+            is_mla=True,
+            qk_nope_head_dim=clip_config['qk_nope_head_dim'],
+            qk_rope_head_dim=clip_config['qk_rope_head_dim'],
+            v_head_dim=clip_config['v_head_dim'],
+        )
+    else:
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+        )


 def compute_scales(p, qk_clip_state):
     """Compute per-head scaling factors for QK clipping.

-    Returns scales tensor if any head exceeds threshold, else None.
+    Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
+    For MLA wkv_b, the effective row stride is qk_nope_head_dim + v_head_dim.
     """
     kind = qk_clip_state.kind
     indices = qk_clip_state.indices
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
     if not head_scales:
         return None

-    H_global = p.shape[0] // head_dim
+    # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
+    if qk_clip_state.is_mla and kind == 'wkv_b':
+        effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
+    else:
+        effective_head_dim = head_dim
+
+    H_global = p.shape[0] // effective_head_dim
     scales_full = torch.ones(H_global, device=p.data.device)
     for head_idx, scale in head_scales.items():
         scales_full[head_idx] = scale
     return scales_full


-def qk_clip(p, scales, head_dim):
-    """Apply per-head scaling to a Q/K projection weight matrix."""
-    if isinstance(p, torch.nn.Parameter):
-        W = p.data.view(-1, head_dim, p.data.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
-    else:
-        W = p.view(-1, head_dim, p.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
+def qk_clip(p, scales, info):
+    """Apply per-head scaling to a Q/K projection weight matrix.
+
+    Args:
+        p: Parameter (nn.Parameter or raw tensor).
+        scales: [n_heads] tensor, each element = √γ_h.
+        info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
+
+    MLA sub-region scaling per Algorithm 1 (MuonClip):
+        wq_b:  q_nope rows → √γ, q_pe rows → γ
+        wkv_b: k_nope rows → √γ, v rows → unchanged
+    """
+    W = p.data if isinstance(p, torch.nn.Parameter) else p
+
+    if not info.is_mla:
+        # MHA/GQA: uniform √γ applied to all rows in each head
+        W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
+        return
+
+    # MLA: vectorized sub-region scaling within each head
+    if info.kind == 'wq_b':
+        qk_nope = info.qk_nope_head_dim
+        qk_head_dim = qk_nope + info.qk_rope_head_dim
+        W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
+        W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # q_nope → √γ
+        W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe → γ

+    elif info.kind == 'wkv_b':
+        qk_nope = info.qk_nope_head_dim
+        kv_stride = qk_nope + info.v_head_dim
+        W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
+        W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # k_nope → √γ
+        # v rows are left untouched: V does not enter the QK logits, and the
+        # shared rotary key k_R comes from wkv_a, so neither is scaled
build/torch210-cxx11-rocm70-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_38f9b8e_dirty
-ops = torch.ops._optimizer_38f9b8e_dirty
+from . import _optimizer_8d53b78_dirty
+ops = torch.ops._optimizer_8d53b78_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_38f9b8e_dirty::{op_name}"
+    return f"_optimizer_8d53b78_dirty::{op_name}"
build/torch210-cxx11-rocm70-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8ec2fcc8a9dc8a1e4aa4e925eaee33613177873e474e8d627bf844dae80f5f8b
+oid sha256:055206c495ecade2fe4b5427db34f0a48152174e79808cbe1ce7d7ca86d32396
 size 1866400
build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py CHANGED
@@ -93,10 +93,7 @@ class CPUOffloadPool:
             indices.append(idx)
            offsets.append((off, n))
            off += n
-        cpu_flat = torch.empty(off,
-                               dtype=dtype,
-                               device="cpu",
-                               pin_memory=True)
+        cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
         self._groups[dtype] = {
             "indices": indices,
             "offsets": offsets,
@@ -140,8 +137,7 @@ class CPUOffloadPool:
         for i, mgd_idx in enumerate(indices):
             local = self._local(self._managed[mgd_idx])
             off, n = offsets[i]
-            cpu_flat[off:off + n].copy_(local.reshape(-1),
-                                        non_blocking=True)
+            cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)

         offloaded_bytes += grp["total"] * cpu_flat.element_size()

@@ -159,8 +155,10 @@ class CPUOffloadPool:
         )

         if not self._logged:
-            logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
-                        offloaded_bytes / (1024**2))
+            logger.info(
+                "[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
+                offloaded_bytes / (1024**2),
+            )

     # ------------------------------------------------------------------
     def reload(self):
@@ -198,12 +196,11 @@ class CPUOffloadPool:
         for i, mgd_idx in enumerate(indices):
             local = self._local(self._managed[mgd_idx])
             off, n = offsets[i]
-            local.reshape(-1).copy_(cpu_flat[off:off + n],
-                                    non_blocking=True)
+            local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)

         reloaded_bytes += grp["total"] * cpu_flat.element_size()

         if not self._logged:
-            logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)",
-                        reloaded_bytes / (1024**2))
-            self._logged = True
+            logger.info(
+                "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
+            )
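The grouped-offload pattern above (one pinned, flat CPU buffer per dtype group, filled and drained with non-blocking slice copies) can be sketched standalone. The tensors and sizes below are illustrative, not the pool's real managed parameters; `pin_memory` is only requested when CUDA is present, since pinning is what makes the `non_blocking=True` copies asynchronous:

```python
import torch

# Illustrative tensors standing in for one dtype group of managed params
params = [torch.randn(3, 4), torch.randn(5)]

# Compute per-tensor (offset, numel) into one flat buffer, as the pool does
offsets, off = [], 0
for t in params:
    n = t.numel()
    offsets.append((off, n))
    off += n

# One pinned CPU buffer for the whole group (pinning needs a CUDA runtime)
pin = torch.cuda.is_available()
cpu_flat = torch.empty(off, dtype=params[0].dtype, device="cpu", pin_memory=pin)

# Offload: flatten each tensor into its slice of the buffer
for t, (o, n) in zip(params, offsets):
    cpu_flat[o : o + n].copy_(t.reshape(-1), non_blocking=True)

# Reload: copy each slice back into a destination tensor in place
restored = [torch.empty_like(t) for t in params]
for r, (o, n) in zip(restored, offsets):
    r.reshape(-1).copy_(cpu_flat[o : o + n], non_blocking=True)

assert all(torch.equal(a, b) for a, b in zip(params, restored))
```

Batching a whole dtype group into one flat buffer replaces many small H2D/D2H transfers with a few large ones, which is the point of the pool's grouping.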
build/torch210-cxx11-rocm70-x86_64-linux/muon.py CHANGED
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
         scales_full = compute_scales(
             p, qk_clip_state) if qk_clip_state is not None else None
         if scales_full is not None:
-            qk_clip(p, scales_full, qk_clip_state.head_dim)
+            qk_clip(p, scales_full, qk_clip_state)

     def distributed_muon(
         self,
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
         scales_full = compute_scales(
             p, qk_clip_state) if qk_clip_state is not None else None
         if scales_full is not None:
-            qk_clip(p, scales_full, qk_clip_state.head_dim)
+            qk_clip(p, scales_full, qk_clip_state)

         if not dtensor_params:
             return
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):

     def state_dict(self) -> dict:
         if self.cpu_offload:
-            raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
+            raise RuntimeError(
+                "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
+            )
         return super().state_dict()

     def load_state_dict(self, state_dict: dict) -> None:
         if self.cpu_offload:
-            raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
+            raise RuntimeError(
+                "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
+            )
         super().load_state_dict(state_dict)

         # Invalidate adamw.py's module-level tensor caches so that
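The call-site change in this file passes the whole `qk_clip_state` to `qk_clip` instead of just `qk_clip_state.head_dim`, since the MLA branch needs the extra sub-head dimensions. A minimal stand-in sketch of that compute-then-clip flow (every name and body here is illustrative, not the real implementation), using the MuonClip rule γ = threshold / max_logit applied as √γ:

```python
from dataclasses import dataclass

@dataclass
class QKClipInfo:
    # Stand-in: the real dataclass carries kind, indices, and a logit tensor
    head_dim: int
    threshold: float
    max_logit: float

def compute_scales(p, info):
    # MuonClip: gamma = threshold / max_logit; applied as sqrt(gamma) to
    # each of W_q and W_k. None means "no head exceeded the threshold".
    if info.max_logit <= info.threshold:
        return None
    return (info.threshold / info.max_logit) ** 0.5

def qk_clip(p, scale, info):
    # Stand-in: the real code scales per-head row blocks of the weight
    p["W"] *= scale

p = {"W": 2.0}
info = QKClipInfo(head_dim=64, threshold=100.0, max_logit=400.0)
scale = compute_scales(p, info)
if scale is not None:
    qk_clip(p, scale, info)  # whole info is passed, not info.head_dim
# sqrt(100 / 400) = 0.5, so the stand-in weight 2.0 becomes 1.0
assert abs(p["W"] - 1.0) < 1e-9
```

When no head's logit exceeds the threshold, `compute_scales` returns `None` and the clip step is skipped, matching the `if scales_full is not None` guard above.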
build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py CHANGED
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
     E = inf
     for _ in range(max_iter):
         old_E = E
-        LHS = np.array([
-            [l, l**3, l**5, 1],
-            [q, q**3, q**5, -1],
-            [r, r**3, r**5, 1],
-            [u, u**3, u**5, -1],
-        ])
+        LHS = np.array(
+            [
+                [l, l**3, l**5, 1],
+                [q, q**3, q**5, -1],
+                [r, r**3, r**5, 1],
+                [u, u**3, u**5, -1],
+            ]
+        )
         a, b, c, E = np.linalg.solve(LHS, np.ones(4))
         if not np.all(np.isfinite([a, b, c, E])):
-            raise ValueError(f"_optimal_quintic: non-finite solve result "
-                             f"a={a}, b={b}, c={c}, E={E}")
+            raise ValueError(
+                f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
+            )
         q, r = np.sqrt(
-            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
-            (10 * c))
+            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
+        )
         if not np.all(np.isfinite([q, r])):
-            raise ValueError(
-                f"_optimal_quintic: non-finite node update q={q}, r={r}")
+            raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
         if abs(old_E - E) <= 1e-15:
             break
     else:
         raise RuntimeError(
-            f"_optimal_quintic: did not converge after {max_iter} iterations")
+            f"_optimal_quintic: did not converge after {max_iter} iterations"
+        )
     return float(a), float(b), float(c)


@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
 # - Polar Express: analytically optimal per step, adapting to the shrinking
 #   singular-value interval [l, u] as iterations progress; converges all
 #   singular values to 1, producing the exact polar factor UV^T.
-_coeffs_list = _optimal_composition(l=1e-3,
-                                    num_iters=10,
-                                    safety_factor_eps=1e-2,
-                                    cushion=0.02)
+_coeffs_list = _optimal_composition(
+    l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
+)


 # This code is adapted from:
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):

     X = X / (X.norm() + 1e-7)
     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+    )
     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     # Perform the NS iterations
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
     X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+    )
     for a, b, c in hs:
         buf1 = torch.bmm(X, X.transpose(-2, -1))
         buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
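The iterations above consume per-step optimal quintic coefficients from `_optimal_composition`. The same idea can be illustrated with the textbook cubic Newton–Schulz step, X ← 1.5·X − 0.5·(XXᵀ)X, which converges more slowly than the tuned quintic but needs no precomputed coefficients. This is a sketch, not the repository's algorithm; the test matrix is built from a known SVD so convergence is deterministic:

```python
import torch

def ns_cubic(G, steps=30):
    # Cubic Newton-Schulz toward the polar factor of G:
    #   X <- 1.5*X - 0.5*(X X^T) X
    # Normalizing by the Frobenius norm puts all singular values in (0, 1],
    # inside the iteration's basin of convergence.
    X = G / (G.norm() + 1e-7)
    for _ in range(steps):
        A = X @ X.T
        X = 1.5 * X - 0.5 * (A @ X)
    return X

# Build a 6x4 matrix with known singular values via orthogonal factors
U, _ = torch.linalg.qr(torch.randn(6, 4, dtype=torch.float64))
V, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.float64))
S = torch.diag(torch.tensor([1.0, 0.5, 0.2, 0.05], dtype=torch.float64))
G = U @ S @ V.T

X = ns_cubic(G)
# X should approximate the polar factor U V^T, so X^T X ~ I
err = (X.T @ X - torch.eye(4, dtype=torch.float64)).abs().max()
assert err < 1e-6
```

The Polar Express coefficients in the real code play the same role as the fixed (1.5, −0.5) pair here, but are re-optimized per step for the current singular-value interval, which is why far fewer iterations suffice.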
build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py CHANGED
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
 def parse_qk_layer(name: str) -> tuple[str | None, int]:
     """
     Parse a parameter name to check if it is a query/key projection layer
-    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+    and return (kind, layer_index).
+
+    Supported kinds:
+        MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
+        MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

     Returns:
         (kind, layer_idx) or (None, -1) if not matched.
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
         'model.5.attn.wk.weight' -> ('wk', 5)
         'model.2.attn.q_proj.weight' -> ('q_proj', 2)
         'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.1.attn.wq_b.weight' -> ('wq_b', 1)
+        'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
         'model.4.attn.v_proj.weight' -> (None, -1)
     """
     parts = normalize_fqn(name).split('.')
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
             layer_idx = int(part)
             break

-    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
         return kind, layer_idx

     return None, -1
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
 @dataclass
 class QKClipInfo:
     """Per-parameter dynamic info computed from config + runtime logits."""
-    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
     indices: list[int]  # which heads to consider for clipping
-    head_dim: int  # from config
+    head_dim: int  # from config (qk_head_dim for MLA wq_b)
     threshold: float  # from config
     logit: torch.Tensor | None

+    # MLA-specific fields
+    is_mla: bool = False
+    qk_nope_head_dim: int = 0
+    qk_rope_head_dim: int = 0
+    v_head_dim: int = 0
+

 def get_qk_clip_info(clip_config, n, qk_logits):
     """Extract QK clipping info for a named parameter.

     Args:
         clip_config: QK clipping configuration dict (or None).
+            MHA/GQA keys: head_dim, threshold, q_indices, k_indices
+            MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
         n: Parameter name string.
         qk_logits: Dict mapping layer indices to logit tensors (or None).

@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
     head_dim = clip_config.get('head_dim')
     threshold = clip_config.get('threshold')
     kind, layer_idx = parse_qk_layer(n)
+    is_mla = clip_config.get('is_mla', False)

     logit, indices = None, []
     if qk_logits is not None and kind is not None:
         logit = qk_logits[layer_idx]
-        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-        indices = clip_config.get(indices_key, []) or []
-
         if isinstance(logit, DTensor):
             # In TP settings, qk_logits may be DTensor
             # We convert it to full tensor here for simplicity
             logit = logit.full_tensor()

-    return QKClipInfo(
-        kind=kind,
-        indices=indices,
-        head_dim=head_dim,
-        threshold=threshold,
-        logit=logit,
-    )
+    if kind in ('wq_b', 'wq', 'q_proj'):
+        indices = clip_config.get('q_indices', []) or []
+    elif kind in ('wkv_b', 'wk', 'k_proj'):
+        indices = clip_config.get('k_indices', []) or []
+
+    if is_mla:
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+            is_mla=True,
+            qk_nope_head_dim=clip_config['qk_nope_head_dim'],
+            qk_rope_head_dim=clip_config['qk_rope_head_dim'],
+            v_head_dim=clip_config['v_head_dim'],
+        )
+    else:
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+        )


 def compute_scales(p, qk_clip_state):
     """Compute per-head scaling factors for QK clipping.

-    Returns scales tensor if any head exceeds threshold, else None.
+    Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
+    For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
     """
     kind = qk_clip_state.kind
     indices = qk_clip_state.indices
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
     if not head_scales:
         return None

-    H_global = p.shape[0] // head_dim
+    # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
+    if qk_clip_state.is_mla and kind == 'wkv_b':
+        effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
+    else:
+        effective_head_dim = head_dim
+
+    H_global = p.shape[0] // effective_head_dim
     scales_full = torch.ones(H_global, device=p.data.device)
     for head_idx, scale in head_scales.items():
         scales_full[head_idx] = scale
     return scales_full


-def qk_clip(p, scales, head_dim):
-    """Apply per-head scaling to a Q/K projection weight matrix."""
-    if isinstance(p, torch.nn.Parameter):
-        W = p.data.view(-1, head_dim, p.data.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
-    else:
-        W = p.view(-1, head_dim, p.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
+def qk_clip(p, scales, info):
+    """Apply per-head scaling to a Q/K projection weight matrix.
+
+    Args:
+        p: Parameter (nn.Parameter or raw tensor).
+        scales: [n_heads] tensor, each element = √γ_h.
+        info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
+
+    MLA sub-region scaling per Algorithm 1 (MuonClip):
+        wq_b: q_nope rows → √γ, q_pe rows → γ
+        wkv_b: k_nope rows → √γ, v rows → unchanged
+    """
+    W = p.data if isinstance(p, torch.nn.Parameter) else p
+
+    if not info.is_mla:
+        # MHA/GQA: uniform √γ applied to all rows in each head
+        W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
+        return
+
+    # MLA: vectorized sub-region scaling within each head
+    if info.kind == 'wq_b':
+        qk_nope = info.qk_nope_head_dim
+        qk_head_dim = qk_nope + info.qk_rope_head_dim
+        W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
+        W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # q_nope → √γ
+        W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe → γ
+
+    elif info.kind == 'wkv_b':
+        qk_nope = info.qk_nope_head_dim
+        kv_stride = qk_nope + info.v_head_dim
+        W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
+        W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # k_nope → √γ
+        # v rows: not touched (k_R shared rotary unchanged)
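The vectorized MLA sub-region scaling added above can be sanity-checked in isolation. A minimal sketch with illustrative dimensions (2 heads, `qk_nope_head_dim=4`, `qk_rope_head_dim=2`; these numbers are made up, not a real model config) verifying the `wq_b` rule: q_nope rows get √γ, q_pe rows get γ, and unclipped heads are untouched:

```python
import torch

# Hypothetical MLA dims for illustration only
H, qk_nope, qk_rope, in_dim = 2, 4, 2, 8
qk_head_dim = qk_nope + qk_rope

W = torch.ones(H * qk_head_dim, in_dim)  # stands in for wq_b.weight
sqrt_gamma = torch.tensor([0.5, 1.0])    # per-head sqrt(gamma); head 1 unclipped

# Same vectorized sub-region scaling as the wq_b branch of qk_clip()
W_3d = W.view(-1, qk_head_dim, W.shape[1])                            # [H, qk_head_dim, in_dim]
W_3d[:, :qk_nope, :].mul_(sqrt_gamma.view(-1, 1, 1))                  # q_nope rows -> sqrt(gamma)
W_3d[:, qk_nope:, :].mul_((sqrt_gamma * sqrt_gamma).view(-1, 1, 1))   # q_pe rows -> gamma

# Head 0: q_nope scaled by 0.5, q_pe by 0.25; head 1: unchanged
assert torch.allclose(W_3d[0, 0], torch.full((in_dim,), 0.5))
assert torch.allclose(W_3d[0, qk_nope], torch.full((in_dim,), 0.25))
assert torch.allclose(W_3d[1], torch.ones(qk_head_dim, in_dim))
```

Because `view` returns a view of `W`'s storage, the in-place `mul_` on each sub-slice updates the underlying weight directly, with no per-head Python loop.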
build/torch210-cxx11-rocm71-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_38f9b8e_dirty
-ops = torch.ops._optimizer_38f9b8e_dirty
+from . import _optimizer_8d53b78_dirty
+ops = torch.ops._optimizer_8d53b78_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_38f9b8e_dirty::{op_name}"
+    return f"_optimizer_8d53b78_dirty::{op_name}"
build/torch210-cxx11-rocm71-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:18373b2e448071735ce724008122f179dd814986925c9cf0fc03f32201b2b1fa
+oid sha256:315ff09ffa88ec806cb8abe49edb2ca6951e9ac34be3d3e10f159093f9576ee0
 size 1866112
build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py CHANGED
@@ -93,10 +93,7 @@ class CPUOffloadPool:
             indices.append(idx)
            offsets.append((off, n))
            off += n
-        cpu_flat = torch.empty(off,
-                               dtype=dtype,
-                               device="cpu",
-                               pin_memory=True)
+        cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
         self._groups[dtype] = {
             "indices": indices,
             "offsets": offsets,
@@ -140,8 +137,7 @@ class CPUOffloadPool:
         for i, mgd_idx in enumerate(indices):
             local = self._local(self._managed[mgd_idx])
             off, n = offsets[i]
-            cpu_flat[off:off + n].copy_(local.reshape(-1),
-                                        non_blocking=True)
+            cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)

         offloaded_bytes += grp["total"] * cpu_flat.element_size()

@@ -159,8 +155,10 @@ class CPUOffloadPool:
         )

         if not self._logged:
-            logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
-                        offloaded_bytes / (1024**2))
+            logger.info(
+                "[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
+                offloaded_bytes / (1024**2),
+            )

     # ------------------------------------------------------------------
     def reload(self):
@@ -198,12 +196,11 @@ class CPUOffloadPool:
         for i, mgd_idx in enumerate(indices):
             local = self._local(self._managed[mgd_idx])
             off, n = offsets[i]
-            local.reshape(-1).copy_(cpu_flat[off:off + n],
-                                    non_blocking=True)
+            local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)

         reloaded_bytes += grp["total"] * cpu_flat.element_size()

         if not self._logged:
-            logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)",
-                        reloaded_bytes / (1024**2))
-            self._logged = True
+            logger.info(
+                "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
+            )
build/torch210-cxx11-rocm71-x86_64-linux/muon.py CHANGED
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
         scales_full = compute_scales(
             p, qk_clip_state) if qk_clip_state is not None else None
         if scales_full is not None:
-            qk_clip(p, scales_full, qk_clip_state.head_dim)
+            qk_clip(p, scales_full, qk_clip_state)

     def distributed_muon(
         self,
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
         scales_full = compute_scales(
             p, qk_clip_state) if qk_clip_state is not None else None
         if scales_full is not None:
-            qk_clip(p, scales_full, qk_clip_state.head_dim)
+            qk_clip(p, scales_full, qk_clip_state)

         if not dtensor_params:
             return
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):

     def state_dict(self) -> dict:
         if self.cpu_offload:
-            raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
+            raise RuntimeError(
+                "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
+            )
         return super().state_dict()

     def load_state_dict(self, state_dict: dict) -> None:
         if self.cpu_offload:
-            raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
+            raise RuntimeError(
+                "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
+            )
         super().load_state_dict(state_dict)

         # Invalidate adamw.py's module-level tensor caches so that
build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py CHANGED
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
     E = inf
     for _ in range(max_iter):
         old_E = E
-        LHS = np.array([
-            [l, l**3, l**5, 1],
-            [q, q**3, q**5, -1],
-            [r, r**3, r**5, 1],
-            [u, u**3, u**5, -1],
-        ])
+        LHS = np.array(
+            [
+                [l, l**3, l**5, 1],
+                [q, q**3, q**5, -1],
+                [r, r**3, r**5, 1],
+                [u, u**3, u**5, -1],
+            ]
+        )
         a, b, c, E = np.linalg.solve(LHS, np.ones(4))
         if not np.all(np.isfinite([a, b, c, E])):
-            raise ValueError(f"_optimal_quintic: non-finite solve result "
-                             f"a={a}, b={b}, c={c}, E={E}")
+            raise ValueError(
+                f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
+            )
         q, r = np.sqrt(
-            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
-            (10 * c))
+            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
+        )
         if not np.all(np.isfinite([q, r])):
-            raise ValueError(
-                f"_optimal_quintic: non-finite node update q={q}, r={r}")
+            raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
         if abs(old_E - E) <= 1e-15:
             break
     else:
         raise RuntimeError(
-            f"_optimal_quintic: did not converge after {max_iter} iterations")
+            f"_optimal_quintic: did not converge after {max_iter} iterations"
+        )
     return float(a), float(b), float(c)


@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
 # - Polar Express: analytically optimal per step, adapting to the shrinking
 #   singular-value interval [l, u] as iterations progress; converges all
 #   singular values to 1, producing the exact polar factor UV^T.
-_coeffs_list = _optimal_composition(l=1e-3,
-                                    num_iters=10,
-                                    safety_factor_eps=1e-2,
-                                    cushion=0.02)
+_coeffs_list = _optimal_composition(
+    l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
+)


 # This code is adapted from:
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):

     X = X / (X.norm() + 1e-7)
     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+    )
     buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
     # Perform the NS iterations
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
     X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

     hs = _coeffs_list[:steps] + list(
-        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+    )
     for a, b, c in hs:
         buf1 = torch.bmm(X, X.transpose(-2, -1))
         buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py CHANGED
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
 def parse_qk_layer(name: str) -> tuple[str | None, int]:
     """
     Parse a parameter name to check if it is a query/key projection layer
-    ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+    and return (kind, layer_index).
+
+    Supported kinds:
+        MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
+        MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

     Returns:
         (kind, layer_idx) or (None, -1) if not matched.
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
         'model.5.attn.wk.weight' -> ('wk', 5)
         'model.2.attn.q_proj.weight' -> ('q_proj', 2)
         'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+        'model.1.attn.wq_b.weight' -> ('wq_b', 1)
+        'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
         'model.4.attn.v_proj.weight' -> (None, -1)
     """
     parts = normalize_fqn(name).split('.')
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
             layer_idx = int(part)
             break

-    if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
         return kind, layer_idx

     return None, -1
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
 @dataclass
 class QKClipInfo:
     """Per-parameter dynamic info computed from config + runtime logits."""
-    kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+    kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
     indices: list[int]  # which heads to consider for clipping
-    head_dim: int  # from config
+    head_dim: int  # from config (qk_head_dim for MLA wq_b)
     threshold: float  # from config
     logit: torch.Tensor | None

+    # MLA-specific fields
+    is_mla: bool = False
+    qk_nope_head_dim: int = 0
+    qk_rope_head_dim: int = 0
+    v_head_dim: int = 0
+

 def get_qk_clip_info(clip_config, n, qk_logits):
     """Extract QK clipping info for a named parameter.

     Args:
         clip_config: QK clipping configuration dict (or None).
+            MHA/GQA keys: head_dim, threshold, q_indices, k_indices
+            MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
         n: Parameter name string.
         qk_logits: Dict mapping layer indices to logit tensors (or None).

@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
     head_dim = clip_config.get('head_dim')
     threshold = clip_config.get('threshold')
     kind, layer_idx = parse_qk_layer(n)
+    is_mla = clip_config.get('is_mla', False)

     logit, indices = None, []
     if qk_logits is not None and kind is not None:
         logit = qk_logits[layer_idx]
-        indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-        indices = clip_config.get(indices_key, []) or []
-
         if isinstance(logit, DTensor):
             # In TP settings, qk_logits may be DTensor
             # We convert it to full tensor here for simplicity
             logit = logit.full_tensor()

-    return QKClipInfo(
-        kind=kind,
-        indices=indices,
-        head_dim=head_dim,
-        threshold=threshold,
-        logit=logit,
-    )
+    if kind in ('wq_b', 'wq', 'q_proj'):
+        indices = clip_config.get('q_indices', []) or []
+    elif kind in ('wkv_b', 'wk', 'k_proj'):
+        indices = clip_config.get('k_indices', []) or []
+
+    if is_mla:
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+            is_mla=True,
+            qk_nope_head_dim=clip_config['qk_nope_head_dim'],
+            qk_rope_head_dim=clip_config['qk_rope_head_dim'],
+            v_head_dim=clip_config['v_head_dim'],
+        )
+    else:
+        return QKClipInfo(
+            kind=kind,
+            indices=indices,
+            head_dim=head_dim,
+            threshold=threshold,
+            logit=logit,
+        )


 def compute_scales(p, qk_clip_state):
     """Compute per-head scaling factors for QK clipping.

-    Returns scales tensor if any head exceeds threshold, else None.
+    Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
+    For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
     """
     kind = qk_clip_state.kind
     indices = qk_clip_state.indices
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
     if not head_scales:
         return None

-    H_global = p.shape[0] // head_dim
+    # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
+    if qk_clip_state.is_mla and kind == 'wkv_b':
+        effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
+    else:
+        effective_head_dim = head_dim
+
+    H_global = p.shape[0] // effective_head_dim
     scales_full = torch.ones(H_global, device=p.data.device)
     for head_idx, scale in head_scales.items():
         scales_full[head_idx] = scale
     return scales_full


-def qk_clip(p, scales, head_dim):
-    """Apply per-head scaling to a Q/K projection weight matrix."""
-    if isinstance(p, torch.nn.Parameter):
-        W = p.data.view(-1, head_dim, p.data.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
-    else:
-        W = p.view(-1, head_dim, p.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
+def qk_clip(p, scales, info):
+    """Apply per-head scaling to a Q/K projection weight matrix.
+
+    Args:
+        p: Parameter (nn.Parameter or raw tensor).
+ scales: [n_heads] tensor, each element = √γ_h.
171
+ info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
172
+
173
+ MLA sub-region scaling per Algorithm 1 (MuonClip):
174
+ wq_b: q_nope rows β†’ √γ, q_pe rows β†’ Ξ³
175
+ wkv_b: k_nope rows β†’ √γ, v rows β†’ unchanged
176
+ """
177
+ W = p.data if isinstance(p, torch.nn.Parameter) else p
178
+
179
+ if not info.is_mla:
180
+ # MHA/GQA: uniform √γ applied to all rows in each head
181
+ W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
182
+ return
183
+
184
+ # MLA: vectorized sub-region scaling within each head
185
+ if info.kind == 'wq_b':
186
+ qk_nope = info.qk_nope_head_dim
187
+ qk_head_dim = qk_nope + info.qk_rope_head_dim
188
+ W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim]
189
+ W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope β†’ √γ
190
+ W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1,
191
+ 1)) # q_pe β†’ Ξ³
192
+
193
+ elif info.kind == 'wkv_b':
194
+ qk_nope = info.qk_nope_head_dim
195
+ kv_stride = qk_nope + info.v_head_dim
196
+ W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim]
197
+ W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope β†’ √γ
198
+ # v rows: not touched (k_R shared rotary unchanged)
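To sanity-check the MLA sub-region rule added above (wq_b: q_nope rows get √γ, q_pe rows get γ), here is a minimal standalone sketch; the dimensions are invented for illustration and do not come from any real config:

```python
import torch

# Hypothetical MLA dims (illustrative only): 2 heads,
# qk_nope_head_dim=4, qk_rope_head_dim=2, in_dim=3
H, nope, rope, in_dim = 2, 4, 2, 3
W = torch.ones(H * (nope + rope), in_dim)

# Per-head sqrt(gamma): head 0 clipped with gamma = 0.25, head 1 untouched
scales = torch.tensor([0.5, 1.0])

W_3d = W.view(H, nope + rope, in_dim)
W_3d[:, :nope, :].mul_(scales.view(-1, 1, 1))             # q_nope rows -> sqrt(gamma)
W_3d[:, nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe rows -> gamma

print(W[0, 0].item())  # 0.5  (head 0, q_nope row)
print(W[4, 0].item())  # 0.25 (head 0, q_pe row)
print(W[6, 0].item())  # 1.0  (head 1 untouched)
```

Since the view aliases `W`, the in-place `mul_` on the two sub-slices updates the original 2-D weight, which is why the optimizer can pass the parameter storage directly.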
build/torch28-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _optimizer_38f9b8e_dirty
- ops = torch.ops._optimizer_38f9b8e_dirty
+ from . import _optimizer_8d53b78_dirty
+ ops = torch.ops._optimizer_8d53b78_dirty

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_optimizer_38f9b8e_dirty::{op_name}"
+     return f"_optimizer_8d53b78_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:d2db9c7fb764a1fae1872779bc9ffac2aff18d14a238111d6b8b53b7d3dfa0d3
+ oid sha256:a9a7c1beffbad405ef7d6f46f44cf9c6671d119e04a340b54c8f4c8f9d699caf
  size 1936664
build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py CHANGED
@@ -93,10 +93,7 @@ class CPUOffloadPool:
      indices.append(idx)
      offsets.append((off, n))
      off += n
-     cpu_flat = torch.empty(off,
-                            dtype=dtype,
-                            device="cpu",
-                            pin_memory=True)
+     cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
      self._groups[dtype] = {
          "indices": indices,
          "offsets": offsets,
@@ -140,8 +137,7 @@ class CPUOffloadPool:
      for i, mgd_idx in enumerate(indices):
          local = self._local(self._managed[mgd_idx])
          off, n = offsets[i]
-         cpu_flat[off:off + n].copy_(local.reshape(-1),
-                                     non_blocking=True)
+         cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)

      offloaded_bytes += grp["total"] * cpu_flat.element_size()

@@ -159,8 +155,10 @@ class CPUOffloadPool:
      )

      if not self._logged:
-         logger.info("[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
-                     offloaded_bytes / (1024**2))
+         logger.info(
+             "[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
+             offloaded_bytes / (1024**2),
+         )

  # ------------------------------------------------------------------
  def reload(self):
@@ -198,12 +196,11 @@ class CPUOffloadPool:
      for i, mgd_idx in enumerate(indices):
          local = self._local(self._managed[mgd_idx])
          off, n = offsets[i]
-         local.reshape(-1).copy_(cpu_flat[off:off + n],
-                                 non_blocking=True)
+         local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)

      reloaded_bytes += grp["total"] * cpu_flat.element_size()

      if not self._logged:
-         logger.info("[CPUOffload] Reloaded %.2f MB (CPU → GPU)",
-                     reloaded_bytes / (1024**2))
-         self._logged = True
+         logger.info(
+             "[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
+         )
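The pattern being reformatted here (pack many parameters into one flat pinned host buffer, then restore them) can be sketched in a CPU-only toy; the tensors and shapes below are invented, and a real offload pool would pass `pin_memory=True` and `non_blocking=True` on a CUDA host:

```python
import torch

# Three stand-in "parameters" of the same dtype
tensors = [torch.full((3,), float(i)) for i in range(3)]

# Record (offset, numel) for each tensor, as the pool above does
offsets, off = [], 0
for t in tensors:
    n = t.numel()
    offsets.append((off, n))
    off += n

flat = torch.empty(off, dtype=tensors[0].dtype)  # pin_memory=True on a CUDA host
for t, (o, n) in zip(tensors, offsets):
    flat[o : o + n].copy_(t.reshape(-1))         # "offload" direction

# "Reload" direction: views back into the flat buffer
restored = [flat[o : o + n].view_as(t) for t, (o, n) in zip(tensors, offsets)]
print(restored[2])  # tensor([2., 2., 2.])
```

A single large pinned allocation amortizes the host-memory registration cost and lets every transfer in the group use one contiguous, DMA-friendly region.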
build/torch28-cxx11-cu126-x86_64-linux/muon.py CHANGED
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
      scales_full = compute_scales(
          p, qk_clip_state) if qk_clip_state is not None else None
      if scales_full is not None:
-         qk_clip(p, scales_full, qk_clip_state.head_dim)
+         qk_clip(p, scales_full, qk_clip_state)

  def distributed_muon(
      self,
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
      scales_full = compute_scales(
          p, qk_clip_state) if qk_clip_state is not None else None
      if scales_full is not None:
-         qk_clip(p, scales_full, qk_clip_state.head_dim)
+         qk_clip(p, scales_full, qk_clip_state)

      if not dtensor_params:
          return
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):

  def state_dict(self) -> dict:
      if self.cpu_offload:
-         raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
+         raise RuntimeError(
+             "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
+         )
      return super().state_dict()

  def load_state_dict(self, state_dict: dict) -> None:
      if self.cpu_offload:
-         raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
+         raise RuntimeError(
+             "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
+         )
      super().load_state_dict(state_dict)

      # Invalidate adamw.py's module-level tensor caches so that
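For reference, the per-head factor that `compute_scales` hands to `qk_clip` follows the MuonClip rule γ_h = min(1, τ / S_h), where S_h is the head's observed max QK logit and τ the threshold; each of the Q and K projections then absorbs √γ_h so that the logit itself shrinks by γ_h. A tiny worked example (the logit values and τ are made up):

```python
# Hypothetical max QK logits per head and threshold tau (illustrative values)
logits = [120.0, 80.0, 250.0]
tau = 100.0

# gamma_h = min(1, tau / S_h); heads below threshold are left at gamma = 1
gammas = [min(1.0, tau / s) for s in logits]

# Each of W_q and W_k absorbs sqrt(gamma_h), so the product scales by gamma_h
scales = [g ** 0.5 for g in gammas]

print(round(gammas[2], 3))  # 0.4
print(round(scales[2] ** 2, 3))  # 0.4
```

After the update, head 2's max logit becomes S · γ = 250 · 0.4 = 100 = τ, while head 1 (already below τ) is untouched.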
build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py CHANGED
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
      E = inf
      for _ in range(max_iter):
          old_E = E
-         LHS = np.array([
-             [l, l**3, l**5, 1],
-             [q, q**3, q**5, -1],
-             [r, r**3, r**5, 1],
-             [u, u**3, u**5, -1],
-         ])
+         LHS = np.array(
+             [
+                 [l, l**3, l**5, 1],
+                 [q, q**3, q**5, -1],
+                 [r, r**3, r**5, 1],
+                 [u, u**3, u**5, -1],
+             ]
+         )
          a, b, c, E = np.linalg.solve(LHS, np.ones(4))
          if not np.all(np.isfinite([a, b, c, E])):
-             raise ValueError(f"_optimal_quintic: non-finite solve result "
-                              f"a={a}, b={b}, c={c}, E={E}")
+             raise ValueError(
+                 f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
+             )
          q, r = np.sqrt(
-             (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
-             (10 * c))
+             (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
+         )
          if not np.all(np.isfinite([q, r])):
-             raise ValueError(
-                 f"_optimal_quintic: non-finite node update q={q}, r={r}")
+             raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
          if abs(old_E - E) <= 1e-15:
              break
      else:
          raise RuntimeError(
-             f"_optimal_quintic: did not converge after {max_iter} iterations")
+             f"_optimal_quintic: did not converge after {max_iter} iterations"
+         )
      return float(a), float(b), float(c)

@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
  # - Polar Express: analytically optimal per step, adapting to the shrinking
  #   singular-value interval [l, u] as iterations progress; converges all
  #   singular values to 1, producing the exact polar factor UV^T.
- _coeffs_list = _optimal_composition(l=1e-3,
-                                     num_iters=10,
-                                     safety_factor_eps=1e-2,
-                                     cushion=0.02)
+ _coeffs_list = _optimal_composition(
+     l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
+ )

  # This code is adapted from:
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):

      X = X / (X.norm() + 1e-7)
      hs = _coeffs_list[:steps] + list(
-         repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+         repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+     )
      buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
      buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
      # Perform the NS iterations
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
      X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

      hs = _coeffs_list[:steps] + list(
-         repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+         repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+     )
      for a, b, c in hs:
          buf1 = torch.bmm(X, X.transpose(-2, -1))
          buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
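The quintic composition above targets the polar factor UV^T of the gradient. As a reference point only (this is not the repo's implementation), the same factor can be reached with the classical cubic Newton–Schulz iteration, which trades the per-step optimal quintic for a fixed, simpler update:

```python
import numpy as np

def polar_newton_schulz(G, steps=50):
    """Classical cubic Newton-Schulz: X <- 1.5 X - 0.5 X X^T X.
    Scaling by the spectral norm puts all singular values in (0, 1],
    and the iteration pushes each of them toward 1, so X converges
    to the polar factor U V^T of G."""
    X = G / np.linalg.norm(G, 2)
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

rng = np.random.default_rng(0)
G = rng.standard_normal((4, 4))
Q = polar_newton_schulz(G)
print(np.allclose(Q @ Q.T, np.eye(4), atol=1e-5))
```

The optimized quintic in `_zeropower_via_newtonschulz5` reaches a comparable result in far fewer matmuls, which is the point of solving for (a, b, c) per step.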
build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py CHANGED
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
  def parse_qk_layer(name: str) -> tuple[str | None, int]:
      """
      Parse a parameter name to check if it is a query/key projection layer
-     ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+     and return (kind, layer_index).
+
+     Supported kinds:
+         MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
+         MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

      Returns:
          (kind, layer_idx) or (None, -1) if not matched.
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
          'model.5.attn.wk.weight' -> ('wk', 5)
          'model.2.attn.q_proj.weight' -> ('q_proj', 2)
          'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+         'model.1.attn.wq_b.weight' -> ('wq_b', 1)
+         'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
          'model.4.attn.v_proj.weight' -> (None, -1)
      """
      parts = normalize_fqn(name).split('.')
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
          layer_idx = int(part)
          break

-     if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+     if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
          return kind, layer_idx

      return None, -1
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
  @dataclass
  class QKClipInfo:
      """Per-parameter dynamic info computed from config + runtime logits."""
-     kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+     kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
      indices: list[int]  # which heads to consider for clipping
-     head_dim: int  # from config
+     head_dim: int  # from config (qk_head_dim for MLA wq_b)
      threshold: float  # from config
      logit: torch.Tensor | None
+
+     # MLA-specific fields
+     is_mla: bool = False
+     qk_nope_head_dim: int = 0
+     qk_rope_head_dim: int = 0
+     v_head_dim: int = 0

  def get_qk_clip_info(clip_config, n, qk_logits):
      """Extract QK clipping info for a named parameter.

      Args:
          clip_config: QK clipping configuration dict (or None).
+             MHA/GQA keys: head_dim, threshold, q_indices, k_indices
+             MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
          n: Parameter name string.
          qk_logits: Dict mapping layer indices to logit tensors (or None).

@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
      head_dim = clip_config.get('head_dim')
      threshold = clip_config.get('threshold')
      kind, layer_idx = parse_qk_layer(n)
+     is_mla = clip_config.get('is_mla', False)

      logit, indices = None, []
      if qk_logits is not None and kind is not None:
          logit = qk_logits[layer_idx]
-         indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-         indices = clip_config.get(indices_key, []) or []
-
          if isinstance(logit, DTensor):
              # In TP settings, qk_logits may be DTensor
              # We convert it to full tensor here for simplicity
              logit = logit.full_tensor()

-     return QKClipInfo(
-         kind=kind,
-         indices=indices,
-         head_dim=head_dim,
-         threshold=threshold,
-         logit=logit,
-     )
+     if kind in ('wq_b', 'wq', 'q_proj'):
+         indices = clip_config.get('q_indices', []) or []
+     elif kind in ('wkv_b', 'wk', 'k_proj'):
+         indices = clip_config.get('k_indices', []) or []
+
+     if is_mla:
+         return QKClipInfo(
+             kind=kind,
+             indices=indices,
+             head_dim=head_dim,
+             threshold=threshold,
+             logit=logit,
+             is_mla=True,
+             qk_nope_head_dim=clip_config['qk_nope_head_dim'],
+             qk_rope_head_dim=clip_config['qk_rope_head_dim'],
+             v_head_dim=clip_config['v_head_dim'],
+         )
+     else:
+         return QKClipInfo(
+             kind=kind,
+             indices=indices,
+             head_dim=head_dim,
+             threshold=threshold,
+             logit=logit,
+         )

  def compute_scales(p, qk_clip_state):
      """Compute per-head scaling factors for QK clipping.

-     Returns scales tensor if any head exceeds threshold, else None.
+     Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
+     For MLA wkv_b, the effective row stride is qk_nope_head_dim + v_head_dim.
      """
      kind = qk_clip_state.kind
      indices = qk_clip_state.indices
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
      if not head_scales:
          return None

-     H_global = p.shape[0] // head_dim
+     # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
+     if qk_clip_state.is_mla and kind == 'wkv_b':
+         effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
+     else:
+         effective_head_dim = head_dim
+
+     H_global = p.shape[0] // effective_head_dim
      scales_full = torch.ones(H_global, device=p.data.device)
      for head_idx, scale in head_scales.items():
          scales_full[head_idx] = scale
      return scales_full

- def qk_clip(p, scales, head_dim):
-     """Apply per-head scaling to a Q/K projection weight matrix."""
-     if isinstance(p, torch.nn.Parameter):
-         W = p.data.view(-1, head_dim, p.data.shape[1])
-         W.mul_(scales.view(-1, 1, 1))
-     else:
-         W = p.view(-1, head_dim, p.shape[1])
-         W.mul_(scales.view(-1, 1, 1))
+ def qk_clip(p, scales, info):
+     """Apply per-head scaling to a Q/K projection weight matrix.
+
+     Args:
+         p: Parameter (nn.Parameter or raw tensor).
+         scales: [n_heads] tensor, each element = √γ_h.
+         info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
+
+     MLA sub-region scaling per Algorithm 1 (MuonClip):
+         wq_b: q_nope rows → √γ, q_pe rows → γ
+         wkv_b: k_nope rows → √γ, v rows → unchanged
+     """
+     W = p.data if isinstance(p, torch.nn.Parameter) else p
+
+     if not info.is_mla:
+         # MHA/GQA: uniform √γ applied to all rows in each head
+         W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
+         return
+
+     # MLA: vectorized sub-region scaling within each head
+     if info.kind == 'wq_b':
+         qk_nope = info.qk_nope_head_dim
+         qk_head_dim = qk_nope + info.qk_rope_head_dim
+         W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
+         W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # q_nope → √γ
+         W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe → γ
+
+     elif info.kind == 'wkv_b':
+         qk_nope = info.qk_nope_head_dim
+         kv_stride = qk_nope + info.v_head_dim
+         W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
+         W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # k_nope → √γ
+         # v rows: not touched (k_R shared rotary unchanged)
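The name matching that `parse_qk_layer` performs can be illustrated with a simplified stand-in; this sketch skips `normalize_fqn` and the early `break` of the real function, so it is illustrative only:

```python
def parse_qk_layer(name):
    """Simplified sketch: find a layer index and a recognized Q/K kind
    among the dot-separated parts of a parameter name."""
    kind, layer_idx = None, -1
    for part in name.split('.'):
        if part.isdigit():
            layer_idx = int(part)
        if part in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
            kind = part
    if kind is not None:
        return kind, layer_idx
    return None, -1

print(parse_qk_layer('model.1.attn.wq_b.weight'))    # ('wq_b', 1)
print(parse_qk_layer('model.4.attn.v_proj.weight'))  # (None, -1)
```

Value projections deliberately fall through to `(None, -1)`: QK-Clip only rescales weights that contribute to attention logits.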
build/torch28-cxx11-cu128-x86_64-linux/_ops.py CHANGED
build/torch28-cxx11-cu128-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:b1ddfe7e38a9213d5dede8052c81b78eca952aef122d4da919950ff504dc3908
+ oid sha256:090f5a44cdfa4554147159cc36bb7e8ee9dba1ffb1fea4825aa838461fdaddf9
  size 1999872
build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py CHANGED
build/torch28-cxx11-cu128-x86_64-linux/muon.py CHANGED
build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py CHANGED
build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py CHANGED
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
  def parse_qk_layer(name: str) -> tuple[str | None, int]:
      """
      Parse a parameter name to check if it is a query/key projection layer
-     ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+     and return (kind, layer_index).
+
+     Supported kinds:
+         MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
+         MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

      Returns:
          (kind, layer_idx) or (None, -1) if not matched.
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
          'model.5.attn.wk.weight' -> ('wk', 5)
          'model.2.attn.q_proj.weight' -> ('q_proj', 2)
          'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+         'model.1.attn.wq_b.weight' -> ('wq_b', 1)
+         'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
          'model.4.attn.v_proj.weight' -> (None, -1)
      """
      parts = normalize_fqn(name).split('.')
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
              layer_idx = int(part)
              break

-     if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+     if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
          return kind, layer_idx

      return None, -1
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
  @dataclass
  class QKClipInfo:
      """Per-parameter dynamic info computed from config + runtime logits."""
-     kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+     kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
      indices: list[int]  # which heads to consider for clipping
-     head_dim: int  # from config
+     head_dim: int  # from config (qk_head_dim for MLA wq_b)
      threshold: float  # from config
      logit: torch.Tensor | None

+     # MLA-specific fields
+     is_mla: bool = False
+     qk_nope_head_dim: int = 0
+     qk_rope_head_dim: int = 0
+     v_head_dim: int = 0
+

  def get_qk_clip_info(clip_config, n, qk_logits):
      """Extract QK clipping info for a named parameter.

      Args:
          clip_config: QK clipping configuration dict (or None).
+             MHA/GQA keys: head_dim, threshold, q_indices, k_indices
+             MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
          n: Parameter name string.
          qk_logits: Dict mapping layer indices to logit tensors (or None).

@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
      head_dim = clip_config.get('head_dim')
      threshold = clip_config.get('threshold')
      kind, layer_idx = parse_qk_layer(n)
+     is_mla = clip_config.get('is_mla', False)

      logit, indices = None, []
      if qk_logits is not None and kind is not None:
          logit = qk_logits[layer_idx]
-         indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-         indices = clip_config.get(indices_key, []) or []
-
          if isinstance(logit, DTensor):
              # In TP settings, qk_logits may be DTensor
              # We convert it to full tensor here for simplicity
              logit = logit.full_tensor()

-     return QKClipInfo(
-         kind=kind,
-         indices=indices,
-         head_dim=head_dim,
-         threshold=threshold,
-         logit=logit,
-     )
+         if kind in ('wq_b', 'wq', 'q_proj'):
+             indices = clip_config.get('q_indices', []) or []
+         elif kind in ('wkv_b', 'wk', 'k_proj'):
+             indices = clip_config.get('k_indices', []) or []
+
+     if is_mla:
+         return QKClipInfo(
+             kind=kind,
+             indices=indices,
+             head_dim=head_dim,
+             threshold=threshold,
+             logit=logit,
+             is_mla=True,
+             qk_nope_head_dim=clip_config['qk_nope_head_dim'],
+             qk_rope_head_dim=clip_config['qk_rope_head_dim'],
+             v_head_dim=clip_config['v_head_dim'],
+         )
+     else:
+         return QKClipInfo(
+             kind=kind,
+             indices=indices,
+             head_dim=head_dim,
+             threshold=threshold,
+             logit=logit,
+         )


  def compute_scales(p, qk_clip_state):
      """Compute per-head scaling factors for QK clipping.

-     Returns scales tensor if any head exceeds threshold, else None.
+     Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
+     For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
      """
      kind = qk_clip_state.kind
      indices = qk_clip_state.indices
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
      if not head_scales:
          return None

-     H_global = p.shape[0] // head_dim
+     # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
+     if qk_clip_state.is_mla and kind == 'wkv_b':
+         effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
+     else:
+         effective_head_dim = head_dim
+
+     H_global = p.shape[0] // effective_head_dim
      scales_full = torch.ones(H_global, device=p.data.device)
      for head_idx, scale in head_scales.items():
          scales_full[head_idx] = scale
      return scales_full


- def qk_clip(p, scales, head_dim):
-     """Apply per-head scaling to a Q/K projection weight matrix."""
-     if isinstance(p, torch.nn.Parameter):
-         W = p.data.view(-1, head_dim, p.data.shape[1])
-         W.mul_(scales.view(-1, 1, 1))
-     else:
-         W = p.view(-1, head_dim, p.shape[1])
-         W.mul_(scales.view(-1, 1, 1))
+ def qk_clip(p, scales, info):
+     """Apply per-head scaling to a Q/K projection weight matrix.
+
+     Args:
+         p: Parameter (nn.Parameter or raw tensor).
+         scales: [n_heads] tensor, each element = √γ_h.
+         info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
+
+     MLA sub-region scaling per Algorithm 1 (MuonClip):
+         wq_b:  q_nope rows β†’ √γ, q_pe rows β†’ Ξ³
+         wkv_b: k_nope rows β†’ √γ, v rows β†’ unchanged
+     """
+     W = p.data if isinstance(p, torch.nn.Parameter) else p
+
+     if not info.is_mla:
+         # MHA/GQA: uniform √γ applied to all rows in each head
+         W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
+         return
+
+     # MLA: vectorized sub-region scaling within each head
+     if info.kind == 'wq_b':
+         qk_nope = info.qk_nope_head_dim
+         qk_head_dim = qk_nope + info.qk_rope_head_dim
+         W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
+         W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # q_nope β†’ √γ
+         W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe β†’ Ξ³
+
+     elif info.kind == 'wkv_b':
+         qk_nope = info.qk_nope_head_dim
+         kv_stride = qk_nope + info.v_head_dim
+         W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
+         W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # k_nope β†’ √γ
+         # v rows left untouched (the shared rotary key k_R is not scaled)
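Why the MLA sub-regions get different exponents: an attention logit decomposes as q_nopeΒ·k_nope + q_peΒ·k_R, and the rotary key k_R is shared across heads, so it cannot be scaled per head. Scaling the q_nope and k_nope rows by √γ shrinks the first term by Ξ³, and scaling the q_pe rows by the full Ξ³ shrinks the second term by Ξ³, so the whole logit shrinks by exactly Ξ³. A small sketch (dimensions, logit values, and the demo function are made up for illustration):

```python
import torch

def clipped_logit_demo():
    torch.manual_seed(0)
    d_nope, d_rope = 8, 4
    q_nope, k_nope = torch.randn(d_nope), torch.randn(d_nope)
    q_pe, k_r = torch.randn(d_rope), torch.randn(d_rope)

    threshold, max_logit = 50.0, 125.0  # pretend this head overflowed
    gamma = threshold / max_logit       # Ξ³ < 1
    s = gamma ** 0.5                    # √γ, one entry of `scales`

    raw = q_nope @ k_nope + q_pe @ k_r
    # MuonClip Algorithm 1 sub-region scaling; k_r is left untouched:
    clipped = (s * q_nope) @ (s * k_nope) + (gamma * q_pe) @ k_r
    return raw, gamma * raw, clipped

raw, target, clipped = clipped_logit_demo()
print(torch.allclose(target, clipped))  # the full logit shrinks by exactly Ξ³
```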
build/torch28-cxx11-cu129-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _optimizer_38f9b8e_dirty
- ops = torch.ops._optimizer_38f9b8e_dirty
+ from . import _optimizer_8d53b78_dirty
+ ops = torch.ops._optimizer_8d53b78_dirty

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_optimizer_38f9b8e_dirty::{op_name}"
+     return f"_optimizer_8d53b78_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:db68ba26f1b022f56a5ab4e6e0204bf26df8922750f32f21be0ad76e2674b717
+ oid sha256:46baa92bf8f5ec5913df4081a01f662049fda475eb01bc7ed0f6154755fa88d5
  size 1999872
build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py CHANGED
@@ -93,10 +93,7 @@ class CPUOffloadPool:
              indices.append(idx)
              offsets.append((off, n))
              off += n
-         cpu_flat = torch.empty(off,
-                                dtype=dtype,
-                                device="cpu",
-                                pin_memory=True)
+         cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
          self._groups[dtype] = {
              "indices": indices,
              "offsets": offsets,
@@ -140,8 +137,7 @@ class CPUOffloadPool:
          for i, mgd_idx in enumerate(indices):
              local = self._local(self._managed[mgd_idx])
              off, n = offsets[i]
-             cpu_flat[off:off + n].copy_(local.reshape(-1),
-                                         non_blocking=True)
+             cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)

          offloaded_bytes += grp["total"] * cpu_flat.element_size()

@@ -159,8 +155,10 @@ class CPUOffloadPool:
          )

          if not self._logged:
-             logger.info("[CPUOffload] Offloaded %.2f MB (GPU β†’ CPU)",
-                         offloaded_bytes / (1024**2))
+             logger.info(
+                 "[CPUOffload] Offloaded %.2f MB (GPU β†’ CPU)",
+                 offloaded_bytes / (1024**2),
+             )

      # ------------------------------------------------------------------
      def reload(self):
@@ -198,12 +196,11 @@ class CPUOffloadPool:
          for i, mgd_idx in enumerate(indices):
              local = self._local(self._managed[mgd_idx])
              off, n = offsets[i]
-             local.reshape(-1).copy_(cpu_flat[off:off + n],
-                                     non_blocking=True)
+             local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)

          reloaded_bytes += grp["total"] * cpu_flat.element_size()

          if not self._logged:
-             logger.info("[CPUOffload] Reloaded %.2f MB (CPU β†’ GPU)",
-                         reloaded_bytes / (1024**2))
-             self._logged = True
+             logger.info(
+                 "[CPUOffload] Reloaded %.2f MB (CPU β†’ GPU)", reloaded_bytes / (1024**2)
+             )
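The cpu_offload changes above only reflow an existing pattern: all same-dtype tensors are packed into one flat CPU buffer (pinned, so the `non_blocking=True` copies can overlap with compute) using per-tensor `(offset, numel)` bookkeeping. A minimal CPU-only sketch of that pattern, with hypothetical helper names rather than the module's actual API:

```python
import torch

def build_flat_pool(tensors):
    """Pack same-dtype tensors into one flat buffer (pinned when CUDA is available)."""
    offsets, off = [], 0
    for t in tensors:
        offsets.append((off, t.numel()))
        off += t.numel()
    flat = torch.empty(off, dtype=tensors[0].dtype, pin_memory=torch.cuda.is_available())
    return flat, offsets

def offload(flat, offsets, tensors):
    # Copy each tensor's flattened view into its slice of the flat buffer.
    for (off, n), t in zip(offsets, tensors):
        flat[off : off + n].copy_(t.reshape(-1), non_blocking=True)

def reload(flat, offsets, tensors):
    # Copy each slice back into the tensor through a flattened view.
    for (off, n), t in zip(offsets, tensors):
        t.reshape(-1).copy_(flat[off : off + n], non_blocking=True)

params = [torch.randn(3, 4), torch.randn(5)]
flat, offs = build_flat_pool(params)
offload(flat, offs, params)
restored = [torch.zeros_like(p) for p in params]
reload(flat, offs, restored)
print(all(torch.equal(p, r) for p, r in zip(params, restored)))
```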
build/torch28-cxx11-cu129-x86_64-linux/muon.py CHANGED
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
          scales_full = compute_scales(
              p, qk_clip_state) if qk_clip_state is not None else None
          if scales_full is not None:
-             qk_clip(p, scales_full, qk_clip_state.head_dim)
+             qk_clip(p, scales_full, qk_clip_state)

      def distributed_muon(
          self,
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
          scales_full = compute_scales(
              p, qk_clip_state) if qk_clip_state is not None else None
          if scales_full is not None:
-             qk_clip(p, scales_full, qk_clip_state.head_dim)
+             qk_clip(p, scales_full, qk_clip_state)

          if not dtensor_params:
              return
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):

      def state_dict(self) -> dict:
          if self.cpu_offload:
-             raise RuntimeError("Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save.")
+             raise RuntimeError(
+                 "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
+             )
          return super().state_dict()

      def load_state_dict(self, state_dict: dict) -> None:
          if self.cpu_offload:
-             raise RuntimeError("Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load.")
+             raise RuntimeError(
+                 "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
+             )
          super().load_state_dict(state_dict)

      # Invalidate adamw.py's module-level tensor caches so that
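The call-site change above is the point of the signature simplification: the optimizer now hands `qk_clip` the whole `QKClipInfo` instead of just `head_dim`, so the MLA branch can see the sub-head dimensions. Upstream of that call, `compute_scales` produces one √γ per overflowing head; its body is mostly elided in this diff, so the following is a hedged, simplified sketch of the per-head computation (the function name and exact logic are illustrative, not the module's code):

```python
import torch

def compute_scales_sketch(max_logits, threshold, indices):
    """Illustrative per-head scales: sqrt(threshold / logit) for heads over threshold."""
    head_scales = {}
    for h in indices:
        logit = float(max_logits[h])
        if logit > threshold:
            head_scales[h] = (threshold / logit) ** 0.5  # √γ_h with Ξ³_h = threshold / logit
    if not head_scales:
        return None  # nothing exceeded the threshold; skip clipping
    scales = torch.ones(len(max_logits))
    for h, s in head_scales.items():
        scales[h] = s
    return scales

# Only head 1 (logit 200 > threshold 100) is clipped, by √(100/200).
scales = compute_scales_sketch(torch.tensor([10.0, 200.0, 50.0]), 100.0, [0, 1, 2])
print(scales)
```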
build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py CHANGED
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
      E = inf
      for _ in range(max_iter):
          old_E = E
-         LHS = np.array([
-             [l, l**3, l**5, 1],
-             [q, q**3, q**5, -1],
-             [r, r**3, r**5, 1],
-             [u, u**3, u**5, -1],
-         ])
+         LHS = np.array(
+             [
+                 [l, l**3, l**5, 1],
+                 [q, q**3, q**5, -1],
+                 [r, r**3, r**5, 1],
+                 [u, u**3, u**5, -1],
+             ]
+         )
          a, b, c, E = np.linalg.solve(LHS, np.ones(4))
          if not np.all(np.isfinite([a, b, c, E])):
-             raise ValueError(f"_optimal_quintic: non-finite solve result "
-                              f"a={a}, b={b}, c={c}, E={E}")
+             raise ValueError(
+                 f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
+             )
          q, r = np.sqrt(
-             (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
-             (10 * c))
+             (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
+         )
          if not np.all(np.isfinite([q, r])):
-             raise ValueError(
-                 f"_optimal_quintic: non-finite node update q={q}, r={r}")
+             raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
          if abs(old_E - E) <= 1e-15:
              break
      else:
          raise RuntimeError(
-             f"_optimal_quintic: did not converge after {max_iter} iterations")
+             f"_optimal_quintic: did not converge after {max_iter} iterations"
+         )
      return float(a), float(b), float(c)


@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
  # - Polar Express: analytically optimal per step, adapting to the shrinking
  #   singular-value interval [l, u] as iterations progress; converges all
  #   singular values to 1, producing the exact polar factor UV^T.
- _coeffs_list = _optimal_composition(l=1e-3,
-                                     num_iters=10,
-                                     safety_factor_eps=1e-2,
-                                     cushion=0.02)
+ _coeffs_list = _optimal_composition(
+     l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
+ )


  # This code is adapted from:
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):

      X = X / (X.norm() + 1e-7)
      hs = _coeffs_list[:steps] + list(
-         repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+         repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+     )
      buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
      buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
      # Perform the NS iterations
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
      X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

      hs = _coeffs_list[:steps] + list(
-         repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
+         repeat(_coeffs_list[-1], steps - len(_coeffs_list))
+     )
      for a, b, c in hs:
          buf1 = torch.bmm(X, X.transpose(-2, -1))
          buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py CHANGED
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
  def parse_qk_layer(name: str) -> tuple[str | None, int]:
      """
      Parse a parameter name to check if it is a query/key projection layer
-     ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
+     and return (kind, layer_index).
+
+     Supported kinds:
+         MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
+         MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)

      Returns:
          (kind, layer_idx) or (None, -1) if not matched.
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
          'model.5.attn.wk.weight' -> ('wk', 5)
          'model.2.attn.q_proj.weight' -> ('q_proj', 2)
          'model.7.attn.k_proj.weight' -> ('k_proj', 7)
+         'model.1.attn.wq_b.weight' -> ('wq_b', 1)
+         'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
          'model.4.attn.v_proj.weight' -> (None, -1)
      """
      parts = normalize_fqn(name).split('.')
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
              layer_idx = int(part)
              break

-     if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
+     if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
          return kind, layer_idx

      return None, -1
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
  @dataclass
  class QKClipInfo:
      """Per-parameter dynamic info computed from config + runtime logits."""
-     kind: str | None  # 'wq'/'q_proj' or 'wk'/'k_proj' or None
+     kind: str | None  # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
      indices: list[int]  # which heads to consider for clipping
-     head_dim: int  # from config
+     head_dim: int  # from config (qk_head_dim for MLA wq_b)
      threshold: float  # from config
      logit: torch.Tensor | None

+     # MLA-specific fields
+     is_mla: bool = False
+     qk_nope_head_dim: int = 0
+     qk_rope_head_dim: int = 0
+     v_head_dim: int = 0
+

  def get_qk_clip_info(clip_config, n, qk_logits):
      """Extract QK clipping info for a named parameter.

      Args:
          clip_config: QK clipping configuration dict (or None).
+             MHA/GQA keys: head_dim, threshold, q_indices, k_indices
+             MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
          n: Parameter name string.
          qk_logits: Dict mapping layer indices to logit tensors (or None).

@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
      head_dim = clip_config.get('head_dim')
      threshold = clip_config.get('threshold')
      kind, layer_idx = parse_qk_layer(n)
+     is_mla = clip_config.get('is_mla', False)

      logit, indices = None, []
      if qk_logits is not None and kind is not None:
          logit = qk_logits[layer_idx]
-         indices_key = 'q_indices' if 'q' in kind else 'k_indices'
-         indices = clip_config.get(indices_key, []) or []
-
          if isinstance(logit, DTensor):
              # In TP settings, qk_logits may be DTensor
              # We convert it to full tensor here for simplicity
              logit = logit.full_tensor()

-     return QKClipInfo(
-         kind=kind,
-         indices=indices,
-         head_dim=head_dim,
-         threshold=threshold,
-         logit=logit,
-     )
+         if kind in ('wq_b', 'wq', 'q_proj'):
+             indices = clip_config.get('q_indices', []) or []
+         elif kind in ('wkv_b', 'wk', 'k_proj'):
+             indices = clip_config.get('k_indices', []) or []
+
+     if is_mla:
+         return QKClipInfo(
+             kind=kind,
+             indices=indices,
+             head_dim=head_dim,
+             threshold=threshold,
+             logit=logit,
+             is_mla=True,
+             qk_nope_head_dim=clip_config['qk_nope_head_dim'],
+             qk_rope_head_dim=clip_config['qk_rope_head_dim'],
+             v_head_dim=clip_config['v_head_dim'],
+         )
+     else:
+         return QKClipInfo(
+             kind=kind,
+             indices=indices,
+             head_dim=head_dim,
+             threshold=threshold,
+             logit=logit,
+         )


  def compute_scales(p, qk_clip_state):
      """Compute per-head scaling factors for QK clipping.

-     Returns scales tensor if any head exceeds threshold, else None.
+     Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
+     For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
      """
      kind = qk_clip_state.kind
      indices = qk_clip_state.indices
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
      if not head_scales:
          return None

-     H_global = p.shape[0] // head_dim
+     # For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
+     if qk_clip_state.is_mla and kind == 'wkv_b':
+         effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
+     else:
+         effective_head_dim = head_dim
+
+     H_global = p.shape[0] // effective_head_dim
      scales_full = torch.ones(H_global, device=p.data.device)
      for head_idx, scale in head_scales.items():
          scales_full[head_idx] = scale
      return scales_full


- def qk_clip(p, scales, head_dim):
-     """Apply per-head scaling to a Q/K projection weight matrix."""
-     if isinstance(p, torch.nn.Parameter):
-         W = p.data.view(-1, head_dim, p.data.shape[1])
-         W.mul_(scales.view(-1, 1, 1))
-     else:
-         W = p.view(-1, head_dim, p.shape[1])
-         W.mul_(scales.view(-1, 1, 1))
+ def qk_clip(p, scales, info):
+     """Apply per-head scaling to a Q/K projection weight matrix.
+
+     Args:
+         p: Parameter (nn.Parameter or raw tensor).
+         scales: [n_heads] tensor, each element = √γ_h.
+         info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
+
+     MLA sub-region scaling per Algorithm 1 (MuonClip):
+         wq_b:  q_nope rows β†’ √γ, q_pe rows β†’ Ξ³
+         wkv_b: k_nope rows β†’ √γ, v rows β†’ unchanged
+     """
+     W = p.data if isinstance(p, torch.nn.Parameter) else p
+
+     if not info.is_mla:
+         # MHA/GQA: uniform √γ applied to all rows in each head
+         W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
+         return
+
+     # MLA: vectorized sub-region scaling within each head
+     if info.kind == 'wq_b':
+         qk_nope = info.qk_nope_head_dim
+         qk_head_dim = qk_nope + info.qk_rope_head_dim
+         W_3d = W.view(-1, qk_head_dim, W.shape[1])  # [H, qk_head_dim, in_dim]
+         W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # q_nope β†’ √γ
+         W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1, 1))  # q_pe β†’ Ξ³
+
+     elif info.kind == 'wkv_b':
+         qk_nope = info.qk_nope_head_dim
+         kv_stride = qk_nope + info.v_head_dim
+         W_3d = W.view(-1, kv_stride, W.shape[1])  # [H, kv_stride, in_dim]
+         W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1))  # k_nope β†’ √γ
+         # v rows left untouched (the shared rotary key k_R is not scaled)
build/torch28-cxx11-rocm63-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _optimizer_38f9b8e_dirty
- ops = torch.ops._optimizer_38f9b8e_dirty
+ from . import _optimizer_8d53b78_dirty
+ ops = torch.ops._optimizer_8d53b78_dirty

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_optimizer_38f9b8e_dirty::{op_name}"
+     return f"_optimizer_8d53b78_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so β†’ _optimizer_8d53b78_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:5a6a8788f055b22d594330fc06487ae2c6eeb2b64e0ab0132b68036a78560cf6
+ oid sha256:bcf5b8838dfaf6e81fdbd52ff4638ca76abaa678f7c2cbd81cf03dc72f9cd5d2
  size 1865080