feat: extend QK-Clip to support MLA (MuonClip Algorithm 1) [skip-build] (#28)
Browse files* feat: extend QK-Clip to support MLA (MuonClip Algorithm 1) [skip-build]
Add MLA support to qk_clip.py following the MuonClip spec (docs/muon-clip.md):
- parse_qk_layer: recognize 'wq_b' and 'wkv_b' MLA weight names
- QKClipInfo: add is_mla, qk_nope_head_dim, qk_rope_head_dim, v_head_dim fields
- get_qk_clip_info: branch on is_mla flag in clip_config
- compute_scales: use kv_stride (qk_nope + v_head_dim) as effective head dim for wkv_b
- qk_clip: simplify signature to (p, scales, info); vectorize MLA sub-region
scaling via tensor reshape instead of Python per-head loops:
wq_b: q_nope rows → √γ, q_pe rows → γ
wkv_b: k_nope rows → √γ, v rows unchanged
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* style: ruff format (pre-commit fix)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add built binary [skip-build]
---------
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: WyldeCat <skan1543@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu126-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} +1 -1
- build/torch210-cxx11-cu126-x86_64-linux/cpu_offload.py +10 -13
- build/torch210-cxx11-cu126-x86_64-linux/muon.py +8 -4
- build/torch210-cxx11-cu126-x86_64-linux/newton_schulz.py +23 -19
- build/torch210-cxx11-cu126-x86_64-linux/qk_clip.py +87 -24
- build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu128-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} +1 -1
- build/torch210-cxx11-cu128-x86_64-linux/cpu_offload.py +10 -13
- build/torch210-cxx11-cu128-x86_64-linux/muon.py +8 -4
- build/torch210-cxx11-cu128-x86_64-linux/newton_schulz.py +23 -19
- build/torch210-cxx11-cu128-x86_64-linux/qk_clip.py +87 -24
- build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu130-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} +1 -1
- build/torch210-cxx11-cu130-x86_64-linux/cpu_offload.py +10 -13
- build/torch210-cxx11-cu130-x86_64-linux/muon.py +8 -4
- build/torch210-cxx11-cu130-x86_64-linux/newton_schulz.py +23 -19
- build/torch210-cxx11-cu130-x86_64-linux/qk_clip.py +87 -24
- build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-rocm70-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} +1 -1
- build/torch210-cxx11-rocm70-x86_64-linux/cpu_offload.py +10 -13
- build/torch210-cxx11-rocm70-x86_64-linux/muon.py +8 -4
- build/torch210-cxx11-rocm70-x86_64-linux/newton_schulz.py +23 -19
- build/torch210-cxx11-rocm70-x86_64-linux/qk_clip.py +87 -24
- build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-rocm71-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} +1 -1
- build/torch210-cxx11-rocm71-x86_64-linux/cpu_offload.py +10 -13
- build/torch210-cxx11-rocm71-x86_64-linux/muon.py +8 -4
- build/torch210-cxx11-rocm71-x86_64-linux/newton_schulz.py +23 -19
- build/torch210-cxx11-rocm71-x86_64-linux/qk_clip.py +87 -24
- build/torch28-cxx11-cu126-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-cu126-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu126-x86_64-linux/cpu_offload.py +10 -13
- build/torch28-cxx11-cu126-x86_64-linux/muon.py +8 -4
- build/torch28-cxx11-cu126-x86_64-linux/newton_schulz.py +23 -19
- build/torch28-cxx11-cu126-x86_64-linux/qk_clip.py +87 -24
- build/torch28-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-cu128-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu128-x86_64-linux/cpu_offload.py +10 -13
- build/torch28-cxx11-cu128-x86_64-linux/muon.py +8 -4
- build/torch28-cxx11-cu128-x86_64-linux/newton_schulz.py +23 -19
- build/torch28-cxx11-cu128-x86_64-linux/qk_clip.py +87 -24
- build/torch28-cxx11-cu129-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-cu129-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu129-x86_64-linux/cpu_offload.py +10 -13
- build/torch28-cxx11-cu129-x86_64-linux/muon.py +8 -4
- build/torch28-cxx11-cu129-x86_64-linux/newton_schulz.py +23 -19
- build/torch28-cxx11-cu129-x86_64-linux/qk_clip.py +87 -24
- build/torch28-cxx11-rocm63-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-rocm63-x86_64-linux/{_optimizer_38f9b8e_dirty.abi3.so → _optimizer_8d53b78_dirty.abi3.so} +1 -1
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import _optimizer_38f9b8e_dirty
|
| 3 |
-
ops = torch.ops._optimizer_38f9b8e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"_optimizer_38f9b8e_dirty::{op_name}"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8d53b78_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8d53b78_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8d53b78_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1940944
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:075fc73dbb2750aed7598cc3e13b593b6b1e7a78a78491e1b852fbd2a9af8f8d
|
| 3 |
size 1940944
|
|
@@ -93,10 +93,7 @@ class CPUOffloadPool:
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
-
cpu_flat = torch.empty(off,
|
| 97 |
-
dtype=dtype,
|
| 98 |
-
device="cpu",
|
| 99 |
-
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
@@ -140,8 +137,7 @@ class CPUOffloadPool:
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
-
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
-
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
|
@@ -159,8 +155,10 @@ class CPUOffloadPool:
|
|
| 159 |
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
-
logger.info(
|
| 163 |
-
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# ------------------------------------------------------------------
|
| 166 |
def reload(self):
|
|
@@ -198,12 +196,11 @@ class CPUOffloadPool:
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
-
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
-
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
| 206 |
if not self._logged:
|
| 207 |
-
logger.info(
|
| 208 |
-
|
| 209 |
-
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
|
|
|
|
|
|
|
|
|
|
| 97 |
self._groups[dtype] = {
|
| 98 |
"indices": indices,
|
| 99 |
"offsets": offsets,
|
|
|
|
| 137 |
for i, mgd_idx in enumerate(indices):
|
| 138 |
local = self._local(self._managed[mgd_idx])
|
| 139 |
off, n = offsets[i]
|
| 140 |
+
cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
|
|
|
|
| 141 |
|
| 142 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 143 |
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
if not self._logged:
|
| 158 |
+
logger.info(
|
| 159 |
+
"[CPUOffload] Offloaded %.2f MB (GPU β CPU)",
|
| 160 |
+
offloaded_bytes / (1024**2),
|
| 161 |
+
)
|
| 162 |
|
| 163 |
# ------------------------------------------------------------------
|
| 164 |
def reload(self):
|
|
|
|
| 196 |
for i, mgd_idx in enumerate(indices):
|
| 197 |
local = self._local(self._managed[mgd_idx])
|
| 198 |
off, n = offsets[i]
|
| 199 |
+
local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
|
|
|
|
| 200 |
|
| 201 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 202 |
|
| 203 |
if not self._logged:
|
| 204 |
+
logger.info(
|
| 205 |
+
"[CPUOffload] Reloaded %.2f MB (CPU β GPU)", reloaded_bytes / (1024**2)
|
| 206 |
+
)
|
|
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1054 |
return super().state_dict()
|
| 1055 |
|
| 1056 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
if self.cpu_offload:
|
| 1058 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError(
|
| 1054 |
+
"Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
|
| 1055 |
+
)
|
| 1056 |
return super().state_dict()
|
| 1057 |
|
| 1058 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1059 |
if self.cpu_offload:
|
| 1060 |
+
raise RuntimeError(
|
| 1061 |
+
"Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
|
| 1062 |
+
)
|
| 1063 |
super().load_state_dict(state_dict)
|
| 1064 |
|
| 1065 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
-
LHS = np.array(
|
| 36 |
-
[
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 42 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 43 |
-
raise ValueError(
|
| 44 |
-
|
|
|
|
| 45 |
q, r = np.sqrt(
|
| 46 |
-
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 47 |
-
|
| 48 |
if not np.all(np.isfinite([q, r])):
|
| 49 |
-
raise ValueError(
|
| 50 |
-
f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
| 51 |
if abs(old_E - E) <= 1e-15:
|
| 52 |
break
|
| 53 |
else:
|
| 54 |
raise RuntimeError(
|
| 55 |
-
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
|
|
|
| 56 |
return float(a), float(b), float(c)
|
| 57 |
|
| 58 |
|
|
@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
|
| 111 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 112 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 113 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 114 |
-
_coeffs_list = _optimal_composition(
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
cushion=0.02)
|
| 118 |
|
| 119 |
|
| 120 |
# This code is adapted from:
|
|
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 148 |
|
| 149 |
X = X / (X.norm() + 1e-7)
|
| 150 |
hs = _coeffs_list[:steps] + list(
|
| 151 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 152 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 153 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 154 |
# Perform the NS iterations
|
|
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
|
|
| 183 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 184 |
|
| 185 |
hs = _coeffs_list[:steps] + list(
|
| 186 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 187 |
for a, b, c in hs:
|
| 188 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 189 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
+
LHS = np.array(
|
| 36 |
+
[
|
| 37 |
+
[l, l**3, l**5, 1],
|
| 38 |
+
[q, q**3, q**5, -1],
|
| 39 |
+
[r, r**3, r**5, 1],
|
| 40 |
+
[u, u**3, u**5, -1],
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 44 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
|
| 47 |
+
)
|
| 48 |
q, r = np.sqrt(
|
| 49 |
+
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
|
| 50 |
+
)
|
| 51 |
if not np.all(np.isfinite([q, r])):
|
| 52 |
+
raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
|
|
|
| 53 |
if abs(old_E - E) <= 1e-15:
|
| 54 |
break
|
| 55 |
else:
|
| 56 |
raise RuntimeError(
|
| 57 |
+
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
| 58 |
+
)
|
| 59 |
return float(a), float(b), float(c)
|
| 60 |
|
| 61 |
|
|
|
|
| 114 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 115 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 116 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 117 |
+
_coeffs_list = _optimal_composition(
|
| 118 |
+
l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
|
| 119 |
+
)
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
# This code is adapted from:
|
|
|
|
| 150 |
|
| 151 |
X = X / (X.norm() + 1e-7)
|
| 152 |
hs = _coeffs_list[:steps] + list(
|
| 153 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 154 |
+
)
|
| 155 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 156 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 157 |
# Perform the NS iterations
|
|
|
|
| 186 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 187 |
|
| 188 |
hs = _coeffs_list[:steps] + list(
|
| 189 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 190 |
+
)
|
| 191 |
for a, b, c in hs:
|
| 192 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 193 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
Returns:
|
| 19 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 23 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 24 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 25 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
|
|
|
|
|
|
| 26 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 27 |
"""
|
| 28 |
parts = normalize_fqn(name).split('.')
|
|
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 37 |
layer_idx = int(part)
|
| 38 |
break
|
| 39 |
|
| 40 |
-
if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
|
| 41 |
return kind, layer_idx
|
| 42 |
|
| 43 |
return None, -1
|
|
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 46 |
@dataclass
|
| 47 |
class QKClipInfo:
|
| 48 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 49 |
-
kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
|
| 50 |
indices: list[int] # which heads to consider for clipping
|
| 51 |
-
head_dim: int # from config
|
| 52 |
threshold: float # from config
|
| 53 |
logit: torch.Tensor | None
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 57 |
"""Extract QK clipping info for a named parameter.
|
| 58 |
|
| 59 |
Args:
|
| 60 |
clip_config: QK clipping configuration dict (or None).
|
|
|
|
|
|
|
| 61 |
n: Parameter name string.
|
| 62 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 63 |
|
|
@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
|
|
| 70 |
head_dim = clip_config.get('head_dim')
|
| 71 |
threshold = clip_config.get('threshold')
|
| 72 |
kind, layer_idx = parse_qk_layer(n)
|
|
|
|
| 73 |
|
| 74 |
logit, indices = None, []
|
| 75 |
if qk_logits is not None and kind is not None:
|
| 76 |
logit = qk_logits[layer_idx]
|
| 77 |
-
indices_key = 'q_indices' if 'q' in kind else 'k_indices'
|
| 78 |
-
indices = clip_config.get(indices_key, []) or []
|
| 79 |
-
|
| 80 |
if isinstance(logit, DTensor):
|
| 81 |
# In TP settings, qk_logits may be DTensor
|
| 82 |
# We convert it to full tensor here for simplicity
|
| 83 |
logit = logit.full_tensor()
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def compute_scales(p, qk_clip_state):
|
| 95 |
"""Compute per-head scaling factors for QK clipping.
|
| 96 |
|
| 97 |
-
Returns scales tensor if any head exceeds threshold, else None.
|
|
|
|
| 98 |
"""
|
| 99 |
kind = qk_clip_state.kind
|
| 100 |
indices = qk_clip_state.indices
|
|
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
|
|
| 118 |
if not head_scales:
|
| 119 |
return None
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 123 |
for head_idx, scale in head_scales.items():
|
| 124 |
scales_full[head_idx] = scale
|
| 125 |
return scales_full
|
| 126 |
|
| 127 |
|
| 128 |
-
def qk_clip(p, scales,
|
| 129 |
-
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
+
and return (kind, layer_index).
|
| 17 |
+
|
| 18 |
+
Supported kinds:
|
| 19 |
+
MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
|
| 20 |
+
MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
|
|
| 27 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 28 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 29 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
| 30 |
+
'model.1.attn.wq_b.weight' -> ('wq_b', 1)
|
| 31 |
+
'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
|
| 32 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 33 |
"""
|
| 34 |
parts = normalize_fqn(name).split('.')
|
|
|
|
| 43 |
layer_idx = int(part)
|
| 44 |
break
|
| 45 |
|
| 46 |
+
if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
|
| 47 |
return kind, layer_idx
|
| 48 |
|
| 49 |
return None, -1
|
|
|
|
| 52 |
@dataclass
|
| 53 |
class QKClipInfo:
|
| 54 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 55 |
+
kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
|
| 56 |
indices: list[int] # which heads to consider for clipping
|
| 57 |
+
head_dim: int # from config (qk_head_dim for MLA wq_b)
|
| 58 |
threshold: float # from config
|
| 59 |
logit: torch.Tensor | None
|
| 60 |
|
| 61 |
+
# MLA-specific fields
|
| 62 |
+
is_mla: bool = False
|
| 63 |
+
qk_nope_head_dim: int = 0
|
| 64 |
+
qk_rope_head_dim: int = 0
|
| 65 |
+
v_head_dim: int = 0
|
| 66 |
+
|
| 67 |
|
| 68 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 69 |
"""Extract QK clipping info for a named parameter.
|
| 70 |
|
| 71 |
Args:
|
| 72 |
clip_config: QK clipping configuration dict (or None).
|
| 73 |
+
MHA/GQA keys: head_dim, threshold, q_indices, k_indices
|
| 74 |
+
MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
|
| 75 |
n: Parameter name string.
|
| 76 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 77 |
|
|
|
|
| 84 |
head_dim = clip_config.get('head_dim')
|
| 85 |
threshold = clip_config.get('threshold')
|
| 86 |
kind, layer_idx = parse_qk_layer(n)
|
| 87 |
+
is_mla = clip_config.get('is_mla', False)
|
| 88 |
|
| 89 |
logit, indices = None, []
|
| 90 |
if qk_logits is not None and kind is not None:
|
| 91 |
logit = qk_logits[layer_idx]
|
|
|
|
|
|
|
|
|
|
| 92 |
if isinstance(logit, DTensor):
|
| 93 |
# In TP settings, qk_logits may be DTensor
|
| 94 |
# We convert it to full tensor here for simplicity
|
| 95 |
logit = logit.full_tensor()
|
| 96 |
|
| 97 |
+
if kind in ('wq_b', 'wq', 'q_proj'):
|
| 98 |
+
indices = clip_config.get('q_indices', []) or []
|
| 99 |
+
elif kind in ('wkv_b', 'wk', 'k_proj'):
|
| 100 |
+
indices = clip_config.get('k_indices', []) or []
|
| 101 |
+
|
| 102 |
+
if is_mla:
|
| 103 |
+
return QKClipInfo(
|
| 104 |
+
kind=kind,
|
| 105 |
+
indices=indices,
|
| 106 |
+
head_dim=head_dim,
|
| 107 |
+
threshold=threshold,
|
| 108 |
+
logit=logit,
|
| 109 |
+
is_mla=True,
|
| 110 |
+
qk_nope_head_dim=clip_config['qk_nope_head_dim'],
|
| 111 |
+
qk_rope_head_dim=clip_config['qk_rope_head_dim'],
|
| 112 |
+
v_head_dim=clip_config['v_head_dim'],
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
return QKClipInfo(
|
| 116 |
+
kind=kind,
|
| 117 |
+
indices=indices,
|
| 118 |
+
head_dim=head_dim,
|
| 119 |
+
threshold=threshold,
|
| 120 |
+
logit=logit,
|
| 121 |
+
)
|
| 122 |
|
| 123 |
|
| 124 |
def compute_scales(p, qk_clip_state):
|
| 125 |
"""Compute per-head scaling factors for QK clipping.
|
| 126 |
|
| 127 |
+
Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
|
| 128 |
+
For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
|
| 129 |
"""
|
| 130 |
kind = qk_clip_state.kind
|
| 131 |
indices = qk_clip_state.indices
|
|
|
|
| 149 |
if not head_scales:
|
| 150 |
return None
|
| 151 |
|
| 152 |
+
# For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
|
| 153 |
+
if qk_clip_state.is_mla and kind == 'wkv_b':
|
| 154 |
+
effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
|
| 155 |
+
else:
|
| 156 |
+
effective_head_dim = head_dim
|
| 157 |
+
|
| 158 |
+
H_global = p.shape[0] // effective_head_dim
|
| 159 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 160 |
for head_idx, scale in head_scales.items():
|
| 161 |
scales_full[head_idx] = scale
|
| 162 |
return scales_full
|
| 163 |
|
| 164 |
|
| 165 |
+
def qk_clip(p, scales, info):
|
| 166 |
+
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
p: Parameter (nn.Parameter or raw tensor).
|
| 170 |
+
scales: [n_heads] tensor, each element = √γ_h.
|
| 171 |
+
info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
|
| 172 |
+
|
| 173 |
+
MLA sub-region scaling per Algorithm 1 (MuonClip):
|
| 174 |
+
wq_b: q_nope rows → √γ, q_pe rows → γ
|
| 175 |
+
wkv_b: k_nope rows → √γ, v rows → unchanged
|
| 176 |
+
"""
|
| 177 |
+
W = p.data if isinstance(p, torch.nn.Parameter) else p
|
| 178 |
+
|
| 179 |
+
if not info.is_mla:
|
| 180 |
+
# MHA/GQA: uniform √γ applied to all rows in each head
|
| 181 |
+
W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
|
| 182 |
+
return
|
| 183 |
+
|
| 184 |
+
# MLA: vectorized sub-region scaling within each head
|
| 185 |
+
if info.kind == 'wq_b':
|
| 186 |
+
qk_nope = info.qk_nope_head_dim
|
| 187 |
+
qk_head_dim = qk_nope + info.qk_rope_head_dim
|
| 188 |
+
W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim]
|
| 189 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ
|
| 190 |
+
W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1,
|
| 191 |
+
1)) # q_pe → γ
|
| 192 |
+
|
| 193 |
+
elif info.kind == 'wkv_b':
|
| 194 |
+
qk_nope = info.qk_nope_head_dim
|
| 195 |
+
kv_stride = qk_nope + info.v_head_dim
|
| 196 |
+
W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim]
|
| 197 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ
|
| 198 |
+
# v rows: not touched (k_R shared rotary unchanged)
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import _optimizer_38f9b8e_dirty
|
| 3 |
-
ops = torch.ops._optimizer_38f9b8e_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"_optimizer_38f9b8e_dirty::{op_name}"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8d53b78_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8d53b78_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8d53b78_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2004144
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2af397ae01c8c01ee0e879f6812bd9df55d152afbcc6713f5c1987d5bce7793b
|
| 3 |
size 2004144
|
|
@@ -93,10 +93,7 @@ class CPUOffloadPool:
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
-
cpu_flat = torch.empty(off,
|
| 97 |
-
dtype=dtype,
|
| 98 |
-
device="cpu",
|
| 99 |
-
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
@@ -140,8 +137,7 @@ class CPUOffloadPool:
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
-
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
-
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
|
@@ -159,8 +155,10 @@ class CPUOffloadPool:
|
|
| 159 |
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
-
logger.info(
|
| 163 |
-
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# ------------------------------------------------------------------
|
| 166 |
def reload(self):
|
|
@@ -198,12 +196,11 @@ class CPUOffloadPool:
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
-
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
-
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
| 206 |
if not self._logged:
|
| 207 |
-
logger.info(
|
| 208 |
-
|
| 209 |
-
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
|
|
|
|
|
|
|
|
|
|
| 97 |
self._groups[dtype] = {
|
| 98 |
"indices": indices,
|
| 99 |
"offsets": offsets,
|
|
|
|
| 137 |
for i, mgd_idx in enumerate(indices):
|
| 138 |
local = self._local(self._managed[mgd_idx])
|
| 139 |
off, n = offsets[i]
|
| 140 |
+
cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
|
|
|
|
| 141 |
|
| 142 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 143 |
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
if not self._logged:
|
| 158 |
+
logger.info(
|
| 159 |
+
"[CPUOffload] Offloaded %.2f MB (GPU β CPU)",
|
| 160 |
+
offloaded_bytes / (1024**2),
|
| 161 |
+
)
|
| 162 |
|
| 163 |
# ------------------------------------------------------------------
|
| 164 |
def reload(self):
|
|
|
|
| 196 |
for i, mgd_idx in enumerate(indices):
|
| 197 |
local = self._local(self._managed[mgd_idx])
|
| 198 |
off, n = offsets[i]
|
| 199 |
+
local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
|
|
|
|
| 200 |
|
| 201 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 202 |
|
| 203 |
if not self._logged:
|
| 204 |
+
logger.info(
|
| 205 |
+
"[CPUOffload] Reloaded %.2f MB (CPU β GPU)", reloaded_bytes / (1024**2)
|
| 206 |
+
)
|
|
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1054 |
return super().state_dict()
|
| 1055 |
|
| 1056 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
if self.cpu_offload:
|
| 1058 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError(
|
| 1054 |
+
"Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
|
| 1055 |
+
)
|
| 1056 |
return super().state_dict()
|
| 1057 |
|
| 1058 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1059 |
if self.cpu_offload:
|
| 1060 |
+
raise RuntimeError(
|
| 1061 |
+
"Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
|
| 1062 |
+
)
|
| 1063 |
super().load_state_dict(state_dict)
|
| 1064 |
|
| 1065 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
-
LHS = np.array(
|
| 36 |
-
[
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 42 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 43 |
-
raise ValueError(
|
| 44 |
-
|
|
|
|
| 45 |
q, r = np.sqrt(
|
| 46 |
-
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 47 |
-
|
| 48 |
if not np.all(np.isfinite([q, r])):
|
| 49 |
-
raise ValueError(
|
| 50 |
-
f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
| 51 |
if abs(old_E - E) <= 1e-15:
|
| 52 |
break
|
| 53 |
else:
|
| 54 |
raise RuntimeError(
|
| 55 |
-
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
|
|
|
| 56 |
return float(a), float(b), float(c)
|
| 57 |
|
| 58 |
|
|
@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
|
| 111 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 112 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 113 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 114 |
-
_coeffs_list = _optimal_composition(
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
cushion=0.02)
|
| 118 |
|
| 119 |
|
| 120 |
# This code is adapted from:
|
|
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 148 |
|
| 149 |
X = X / (X.norm() + 1e-7)
|
| 150 |
hs = _coeffs_list[:steps] + list(
|
| 151 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 152 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 153 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 154 |
# Perform the NS iterations
|
|
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
|
|
| 183 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 184 |
|
| 185 |
hs = _coeffs_list[:steps] + list(
|
| 186 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 187 |
for a, b, c in hs:
|
| 188 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 189 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
+
LHS = np.array(
|
| 36 |
+
[
|
| 37 |
+
[l, l**3, l**5, 1],
|
| 38 |
+
[q, q**3, q**5, -1],
|
| 39 |
+
[r, r**3, r**5, 1],
|
| 40 |
+
[u, u**3, u**5, -1],
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 44 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
|
| 47 |
+
)
|
| 48 |
q, r = np.sqrt(
|
| 49 |
+
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
|
| 50 |
+
)
|
| 51 |
if not np.all(np.isfinite([q, r])):
|
| 52 |
+
raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
|
|
|
| 53 |
if abs(old_E - E) <= 1e-15:
|
| 54 |
break
|
| 55 |
else:
|
| 56 |
raise RuntimeError(
|
| 57 |
+
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
| 58 |
+
)
|
| 59 |
return float(a), float(b), float(c)
|
| 60 |
|
| 61 |
|
|
|
|
| 114 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 115 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 116 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 117 |
+
_coeffs_list = _optimal_composition(
|
| 118 |
+
l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
|
| 119 |
+
)
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
# This code is adapted from:
|
|
|
|
| 150 |
|
| 151 |
X = X / (X.norm() + 1e-7)
|
| 152 |
hs = _coeffs_list[:steps] + list(
|
| 153 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 154 |
+
)
|
| 155 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 156 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 157 |
# Perform the NS iterations
|
|
|
|
| 186 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 187 |
|
| 188 |
hs = _coeffs_list[:steps] + list(
|
| 189 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 190 |
+
)
|
| 191 |
for a, b, c in hs:
|
| 192 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 193 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
Returns:
|
| 19 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 23 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 24 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 25 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
|
|
|
|
|
|
| 26 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 27 |
"""
|
| 28 |
parts = normalize_fqn(name).split('.')
|
|
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 37 |
layer_idx = int(part)
|
| 38 |
break
|
| 39 |
|
| 40 |
-
if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
|
| 41 |
return kind, layer_idx
|
| 42 |
|
| 43 |
return None, -1
|
|
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 46 |
@dataclass
|
| 47 |
class QKClipInfo:
|
| 48 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 49 |
-
kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
|
| 50 |
indices: list[int] # which heads to consider for clipping
|
| 51 |
-
head_dim: int # from config
|
| 52 |
threshold: float # from config
|
| 53 |
logit: torch.Tensor | None
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 57 |
"""Extract QK clipping info for a named parameter.
|
| 58 |
|
| 59 |
Args:
|
| 60 |
clip_config: QK clipping configuration dict (or None).
|
|
|
|
|
|
|
| 61 |
n: Parameter name string.
|
| 62 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 63 |
|
|
@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
|
|
| 70 |
head_dim = clip_config.get('head_dim')
|
| 71 |
threshold = clip_config.get('threshold')
|
| 72 |
kind, layer_idx = parse_qk_layer(n)
|
|
|
|
| 73 |
|
| 74 |
logit, indices = None, []
|
| 75 |
if qk_logits is not None and kind is not None:
|
| 76 |
logit = qk_logits[layer_idx]
|
| 77 |
-
indices_key = 'q_indices' if 'q' in kind else 'k_indices'
|
| 78 |
-
indices = clip_config.get(indices_key, []) or []
|
| 79 |
-
|
| 80 |
if isinstance(logit, DTensor):
|
| 81 |
# In TP settings, qk_logits may be DTensor
|
| 82 |
# We convert it to full tensor here for simplicity
|
| 83 |
logit = logit.full_tensor()
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def compute_scales(p, qk_clip_state):
|
| 95 |
"""Compute per-head scaling factors for QK clipping.
|
| 96 |
|
| 97 |
-
Returns scales tensor if any head exceeds threshold, else None.
|
|
|
|
| 98 |
"""
|
| 99 |
kind = qk_clip_state.kind
|
| 100 |
indices = qk_clip_state.indices
|
|
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
|
|
| 118 |
if not head_scales:
|
| 119 |
return None
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 123 |
for head_idx, scale in head_scales.items():
|
| 124 |
scales_full[head_idx] = scale
|
| 125 |
return scales_full
|
| 126 |
|
| 127 |
|
| 128 |
-
def qk_clip(p, scales,
|
| 129 |
-
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
+
and return (kind, layer_index).
|
| 17 |
+
|
| 18 |
+
Supported kinds:
|
| 19 |
+
MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
|
| 20 |
+
MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
|
|
| 27 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 28 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 29 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
| 30 |
+
'model.1.attn.wq_b.weight' -> ('wq_b', 1)
|
| 31 |
+
'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
|
| 32 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 33 |
"""
|
| 34 |
parts = normalize_fqn(name).split('.')
|
|
|
|
| 43 |
layer_idx = int(part)
|
| 44 |
break
|
| 45 |
|
| 46 |
+
if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
|
| 47 |
return kind, layer_idx
|
| 48 |
|
| 49 |
return None, -1
|
|
|
|
| 52 |
@dataclass
|
| 53 |
class QKClipInfo:
|
| 54 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 55 |
+
kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
|
| 56 |
indices: list[int] # which heads to consider for clipping
|
| 57 |
+
head_dim: int # from config (qk_head_dim for MLA wq_b)
|
| 58 |
threshold: float # from config
|
| 59 |
logit: torch.Tensor | None
|
| 60 |
|
| 61 |
+
# MLA-specific fields
|
| 62 |
+
is_mla: bool = False
|
| 63 |
+
qk_nope_head_dim: int = 0
|
| 64 |
+
qk_rope_head_dim: int = 0
|
| 65 |
+
v_head_dim: int = 0
|
| 66 |
+
|
| 67 |
|
| 68 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 69 |
"""Extract QK clipping info for a named parameter.
|
| 70 |
|
| 71 |
Args:
|
| 72 |
clip_config: QK clipping configuration dict (or None).
|
| 73 |
+
MHA/GQA keys: head_dim, threshold, q_indices, k_indices
|
| 74 |
+
MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
|
| 75 |
n: Parameter name string.
|
| 76 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 77 |
|
|
|
|
| 84 |
head_dim = clip_config.get('head_dim')
|
| 85 |
threshold = clip_config.get('threshold')
|
| 86 |
kind, layer_idx = parse_qk_layer(n)
|
| 87 |
+
is_mla = clip_config.get('is_mla', False)
|
| 88 |
|
| 89 |
logit, indices = None, []
|
| 90 |
if qk_logits is not None and kind is not None:
|
| 91 |
logit = qk_logits[layer_idx]
|
|
|
|
|
|
|
|
|
|
| 92 |
if isinstance(logit, DTensor):
|
| 93 |
# In TP settings, qk_logits may be DTensor
|
| 94 |
# We convert it to full tensor here for simplicity
|
| 95 |
logit = logit.full_tensor()
|
| 96 |
|
| 97 |
+
if kind in ('wq_b', 'wq', 'q_proj'):
|
| 98 |
+
indices = clip_config.get('q_indices', []) or []
|
| 99 |
+
elif kind in ('wkv_b', 'wk', 'k_proj'):
|
| 100 |
+
indices = clip_config.get('k_indices', []) or []
|
| 101 |
+
|
| 102 |
+
if is_mla:
|
| 103 |
+
return QKClipInfo(
|
| 104 |
+
kind=kind,
|
| 105 |
+
indices=indices,
|
| 106 |
+
head_dim=head_dim,
|
| 107 |
+
threshold=threshold,
|
| 108 |
+
logit=logit,
|
| 109 |
+
is_mla=True,
|
| 110 |
+
qk_nope_head_dim=clip_config['qk_nope_head_dim'],
|
| 111 |
+
qk_rope_head_dim=clip_config['qk_rope_head_dim'],
|
| 112 |
+
v_head_dim=clip_config['v_head_dim'],
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
return QKClipInfo(
|
| 116 |
+
kind=kind,
|
| 117 |
+
indices=indices,
|
| 118 |
+
head_dim=head_dim,
|
| 119 |
+
threshold=threshold,
|
| 120 |
+
logit=logit,
|
| 121 |
+
)
|
| 122 |
|
| 123 |
|
| 124 |
def compute_scales(p, qk_clip_state):
|
| 125 |
"""Compute per-head scaling factors for QK clipping.
|
| 126 |
|
| 127 |
+
Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
|
| 128 |
+
For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
|
| 129 |
"""
|
| 130 |
kind = qk_clip_state.kind
|
| 131 |
indices = qk_clip_state.indices
|
|
|
|
| 149 |
if not head_scales:
|
| 150 |
return None
|
| 151 |
|
| 152 |
+
# For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
|
| 153 |
+
if qk_clip_state.is_mla and kind == 'wkv_b':
|
| 154 |
+
effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
|
| 155 |
+
else:
|
| 156 |
+
effective_head_dim = head_dim
|
| 157 |
+
|
| 158 |
+
H_global = p.shape[0] // effective_head_dim
|
| 159 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 160 |
for head_idx, scale in head_scales.items():
|
| 161 |
scales_full[head_idx] = scale
|
| 162 |
return scales_full
|
| 163 |
|
| 164 |
|
| 165 |
+
def qk_clip(p, scales, info):
|
| 166 |
+
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
p: Parameter (nn.Parameter or raw tensor).
|
| 170 |
+
scales: [n_heads] tensor, each element = √γ_h.
|
| 171 |
+
info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
|
| 172 |
+
|
| 173 |
+
MLA sub-region scaling per Algorithm 1 (MuonClip):
|
| 174 |
+
wq_b: q_nope rows scaled by √γ, q_pe rows scaled by γ
|
| 175 |
+
wkv_b: k_nope rows scaled by √γ, v rows unchanged
|
| 176 |
+
"""
|
| 177 |
+
W = p.data if isinstance(p, torch.nn.Parameter) else p
|
| 178 |
+
|
| 179 |
+
if not info.is_mla:
|
| 180 |
+
# MHA/GQA: uniform √γ applied to all rows in each head
|
| 181 |
+
W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
|
| 182 |
+
return
|
| 183 |
+
|
| 184 |
+
# MLA: vectorized sub-region scaling within each head
|
| 185 |
+
if info.kind == 'wq_b':
|
| 186 |
+
qk_nope = info.qk_nope_head_dim
|
| 187 |
+
qk_head_dim = qk_nope + info.qk_rope_head_dim
|
| 188 |
+
W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim]
|
| 189 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope β βΞ³
|
| 190 |
+
W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1,
|
| 191 |
+
1)) # q_pe β Ξ³
|
| 192 |
+
|
| 193 |
+
elif info.kind == 'wkv_b':
|
| 194 |
+
qk_nope = info.qk_nope_head_dim
|
| 195 |
+
kv_stride = qk_nope + info.v_head_dim
|
| 196 |
+
W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim]
|
| 197 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope β βΞ³
|
| 198 |
+
# v rows: not touched (k_R shared rotary unchanged)
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8d53b78_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8d53b78_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8d53b78_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2004728
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:45eef069a7caa85678cd1e05f0c60c5cfbc676dc93a1bcb31e55eb34730aa469
|
| 3 |
size 2004728
|
|
@@ -93,10 +93,7 @@ class CPUOffloadPool:
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
-
cpu_flat = torch.empty(off,
|
| 97 |
-
dtype=dtype,
|
| 98 |
-
device="cpu",
|
| 99 |
-
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
@@ -140,8 +137,7 @@ class CPUOffloadPool:
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
-
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
-
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
|
@@ -159,8 +155,10 @@ class CPUOffloadPool:
|
|
| 159 |
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
-
logger.info(
|
| 163 |
-
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# ------------------------------------------------------------------
|
| 166 |
def reload(self):
|
|
@@ -198,12 +196,11 @@ class CPUOffloadPool:
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
-
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
-
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
| 206 |
if not self._logged:
|
| 207 |
-
logger.info(
|
| 208 |
-
|
| 209 |
-
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
|
|
|
|
|
|
|
|
|
|
| 97 |
self._groups[dtype] = {
|
| 98 |
"indices": indices,
|
| 99 |
"offsets": offsets,
|
|
|
|
| 137 |
for i, mgd_idx in enumerate(indices):
|
| 138 |
local = self._local(self._managed[mgd_idx])
|
| 139 |
off, n = offsets[i]
|
| 140 |
+
cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
|
|
|
|
| 141 |
|
| 142 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 143 |
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
if not self._logged:
|
| 158 |
+
logger.info(
|
| 159 |
+
"[CPUOffload] Offloaded %.2f MB (GPU β CPU)",
|
| 160 |
+
offloaded_bytes / (1024**2),
|
| 161 |
+
)
|
| 162 |
|
| 163 |
# ------------------------------------------------------------------
|
| 164 |
def reload(self):
|
|
|
|
| 196 |
for i, mgd_idx in enumerate(indices):
|
| 197 |
local = self._local(self._managed[mgd_idx])
|
| 198 |
off, n = offsets[i]
|
| 199 |
+
local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
|
|
|
|
| 200 |
|
| 201 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 202 |
|
| 203 |
if not self._logged:
|
| 204 |
+
logger.info(
|
| 205 |
+
"[CPUOffload] Reloaded %.2f MB (CPU β GPU)", reloaded_bytes / (1024**2)
|
| 206 |
+
)
|
|
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1054 |
return super().state_dict()
|
| 1055 |
|
| 1056 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
if self.cpu_offload:
|
| 1058 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError(
|
| 1054 |
+
"Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
|
| 1055 |
+
)
|
| 1056 |
return super().state_dict()
|
| 1057 |
|
| 1058 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1059 |
if self.cpu_offload:
|
| 1060 |
+
raise RuntimeError(
|
| 1061 |
+
"Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
|
| 1062 |
+
)
|
| 1063 |
super().load_state_dict(state_dict)
|
| 1064 |
|
| 1065 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
-
LHS = np.array(
|
| 36 |
-
[
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 42 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 43 |
-
raise ValueError(
|
| 44 |
-
|
|
|
|
| 45 |
q, r = np.sqrt(
|
| 46 |
-
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 47 |
-
|
| 48 |
if not np.all(np.isfinite([q, r])):
|
| 49 |
-
raise ValueError(
|
| 50 |
-
f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
| 51 |
if abs(old_E - E) <= 1e-15:
|
| 52 |
break
|
| 53 |
else:
|
| 54 |
raise RuntimeError(
|
| 55 |
-
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
|
|
|
| 56 |
return float(a), float(b), float(c)
|
| 57 |
|
| 58 |
|
|
@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
|
| 111 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 112 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 113 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 114 |
-
_coeffs_list = _optimal_composition(
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
cushion=0.02)
|
| 118 |
|
| 119 |
|
| 120 |
# This code is adapted from:
|
|
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 148 |
|
| 149 |
X = X / (X.norm() + 1e-7)
|
| 150 |
hs = _coeffs_list[:steps] + list(
|
| 151 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 152 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 153 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 154 |
# Perform the NS iterations
|
|
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
|
|
| 183 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 184 |
|
| 185 |
hs = _coeffs_list[:steps] + list(
|
| 186 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 187 |
for a, b, c in hs:
|
| 188 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 189 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
+
LHS = np.array(
|
| 36 |
+
[
|
| 37 |
+
[l, l**3, l**5, 1],
|
| 38 |
+
[q, q**3, q**5, -1],
|
| 39 |
+
[r, r**3, r**5, 1],
|
| 40 |
+
[u, u**3, u**5, -1],
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 44 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
|
| 47 |
+
)
|
| 48 |
q, r = np.sqrt(
|
| 49 |
+
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
|
| 50 |
+
)
|
| 51 |
if not np.all(np.isfinite([q, r])):
|
| 52 |
+
raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
|
|
|
| 53 |
if abs(old_E - E) <= 1e-15:
|
| 54 |
break
|
| 55 |
else:
|
| 56 |
raise RuntimeError(
|
| 57 |
+
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
| 58 |
+
)
|
| 59 |
return float(a), float(b), float(c)
|
| 60 |
|
| 61 |
|
|
|
|
| 114 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 115 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 116 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 117 |
+
_coeffs_list = _optimal_composition(
|
| 118 |
+
l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
|
| 119 |
+
)
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
# This code is adapted from:
|
|
|
|
| 150 |
|
| 151 |
X = X / (X.norm() + 1e-7)
|
| 152 |
hs = _coeffs_list[:steps] + list(
|
| 153 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 154 |
+
)
|
| 155 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 156 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 157 |
# Perform the NS iterations
|
|
|
|
| 186 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 187 |
|
| 188 |
hs = _coeffs_list[:steps] + list(
|
| 189 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 190 |
+
)
|
| 191 |
for a, b, c in hs:
|
| 192 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 193 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
Returns:
|
| 19 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 23 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 24 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 25 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
|
|
|
|
|
|
| 26 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 27 |
"""
|
| 28 |
parts = normalize_fqn(name).split('.')
|
|
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 37 |
layer_idx = int(part)
|
| 38 |
break
|
| 39 |
|
| 40 |
-
if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
|
| 41 |
return kind, layer_idx
|
| 42 |
|
| 43 |
return None, -1
|
|
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 46 |
@dataclass
|
| 47 |
class QKClipInfo:
|
| 48 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 49 |
-
kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
|
| 50 |
indices: list[int] # which heads to consider for clipping
|
| 51 |
-
head_dim: int # from config
|
| 52 |
threshold: float # from config
|
| 53 |
logit: torch.Tensor | None
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 57 |
"""Extract QK clipping info for a named parameter.
|
| 58 |
|
| 59 |
Args:
|
| 60 |
clip_config: QK clipping configuration dict (or None).
|
|
|
|
|
|
|
| 61 |
n: Parameter name string.
|
| 62 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 63 |
|
|
@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
|
|
| 70 |
head_dim = clip_config.get('head_dim')
|
| 71 |
threshold = clip_config.get('threshold')
|
| 72 |
kind, layer_idx = parse_qk_layer(n)
|
|
|
|
| 73 |
|
| 74 |
logit, indices = None, []
|
| 75 |
if qk_logits is not None and kind is not None:
|
| 76 |
logit = qk_logits[layer_idx]
|
| 77 |
-
indices_key = 'q_indices' if 'q' in kind else 'k_indices'
|
| 78 |
-
indices = clip_config.get(indices_key, []) or []
|
| 79 |
-
|
| 80 |
if isinstance(logit, DTensor):
|
| 81 |
# In TP settings, qk_logits may be DTensor
|
| 82 |
# We convert it to full tensor here for simplicity
|
| 83 |
logit = logit.full_tensor()
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def compute_scales(p, qk_clip_state):
|
| 95 |
"""Compute per-head scaling factors for QK clipping.
|
| 96 |
|
| 97 |
-
Returns scales tensor if any head exceeds threshold, else None.
|
|
|
|
| 98 |
"""
|
| 99 |
kind = qk_clip_state.kind
|
| 100 |
indices = qk_clip_state.indices
|
|
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
|
|
| 118 |
if not head_scales:
|
| 119 |
return None
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 123 |
for head_idx, scale in head_scales.items():
|
| 124 |
scales_full[head_idx] = scale
|
| 125 |
return scales_full
|
| 126 |
|
| 127 |
|
| 128 |
-
def qk_clip(p, scales,
|
| 129 |
-
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
+
and return (kind, layer_index).
|
| 17 |
+
|
| 18 |
+
Supported kinds:
|
| 19 |
+
MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
|
| 20 |
+
MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
|
|
| 27 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 28 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 29 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
| 30 |
+
'model.1.attn.wq_b.weight' -> ('wq_b', 1)
|
| 31 |
+
'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
|
| 32 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 33 |
"""
|
| 34 |
parts = normalize_fqn(name).split('.')
|
|
|
|
| 43 |
layer_idx = int(part)
|
| 44 |
break
|
| 45 |
|
| 46 |
+
if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
|
| 47 |
return kind, layer_idx
|
| 48 |
|
| 49 |
return None, -1
|
|
|
|
| 52 |
@dataclass
|
| 53 |
class QKClipInfo:
|
| 54 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 55 |
+
kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
|
| 56 |
indices: list[int] # which heads to consider for clipping
|
| 57 |
+
head_dim: int # from config (qk_head_dim for MLA wq_b)
|
| 58 |
threshold: float # from config
|
| 59 |
logit: torch.Tensor | None
|
| 60 |
|
| 61 |
+
# MLA-specific fields
|
| 62 |
+
is_mla: bool = False
|
| 63 |
+
qk_nope_head_dim: int = 0
|
| 64 |
+
qk_rope_head_dim: int = 0
|
| 65 |
+
v_head_dim: int = 0
|
| 66 |
+
|
| 67 |
|
| 68 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 69 |
"""Extract QK clipping info for a named parameter.
|
| 70 |
|
| 71 |
Args:
|
| 72 |
clip_config: QK clipping configuration dict (or None).
|
| 73 |
+
MHA/GQA keys: head_dim, threshold, q_indices, k_indices
|
| 74 |
+
MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
|
| 75 |
n: Parameter name string.
|
| 76 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 77 |
|
|
|
|
| 84 |
head_dim = clip_config.get('head_dim')
|
| 85 |
threshold = clip_config.get('threshold')
|
| 86 |
kind, layer_idx = parse_qk_layer(n)
|
| 87 |
+
is_mla = clip_config.get('is_mla', False)
|
| 88 |
|
| 89 |
logit, indices = None, []
|
| 90 |
if qk_logits is not None and kind is not None:
|
| 91 |
logit = qk_logits[layer_idx]
|
|
|
|
|
|
|
|
|
|
| 92 |
if isinstance(logit, DTensor):
|
| 93 |
# In TP settings, qk_logits may be DTensor
|
| 94 |
# We convert it to full tensor here for simplicity
|
| 95 |
logit = logit.full_tensor()
|
| 96 |
|
| 97 |
+
if kind in ('wq_b', 'wq', 'q_proj'):
|
| 98 |
+
indices = clip_config.get('q_indices', []) or []
|
| 99 |
+
elif kind in ('wkv_b', 'wk', 'k_proj'):
|
| 100 |
+
indices = clip_config.get('k_indices', []) or []
|
| 101 |
+
|
| 102 |
+
if is_mla:
|
| 103 |
+
return QKClipInfo(
|
| 104 |
+
kind=kind,
|
| 105 |
+
indices=indices,
|
| 106 |
+
head_dim=head_dim,
|
| 107 |
+
threshold=threshold,
|
| 108 |
+
logit=logit,
|
| 109 |
+
is_mla=True,
|
| 110 |
+
qk_nope_head_dim=clip_config['qk_nope_head_dim'],
|
| 111 |
+
qk_rope_head_dim=clip_config['qk_rope_head_dim'],
|
| 112 |
+
v_head_dim=clip_config['v_head_dim'],
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
return QKClipInfo(
|
| 116 |
+
kind=kind,
|
| 117 |
+
indices=indices,
|
| 118 |
+
head_dim=head_dim,
|
| 119 |
+
threshold=threshold,
|
| 120 |
+
logit=logit,
|
| 121 |
+
)
|
| 122 |
|
| 123 |
|
| 124 |
def compute_scales(p, qk_clip_state):
|
| 125 |
"""Compute per-head scaling factors for QK clipping.
|
| 126 |
|
| 127 |
+
Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
|
| 128 |
+
For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
|
| 129 |
"""
|
| 130 |
kind = qk_clip_state.kind
|
| 131 |
indices = qk_clip_state.indices
|
|
|
|
| 149 |
if not head_scales:
|
| 150 |
return None
|
| 151 |
|
| 152 |
+
# For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
|
| 153 |
+
if qk_clip_state.is_mla and kind == 'wkv_b':
|
| 154 |
+
effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
|
| 155 |
+
else:
|
| 156 |
+
effective_head_dim = head_dim
|
| 157 |
+
|
| 158 |
+
H_global = p.shape[0] // effective_head_dim
|
| 159 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 160 |
for head_idx, scale in head_scales.items():
|
| 161 |
scales_full[head_idx] = scale
|
| 162 |
return scales_full
|
| 163 |
|
| 164 |
|
| 165 |
+
def qk_clip(p, scales, info):
|
| 166 |
+
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
p: Parameter (nn.Parameter or raw tensor).
|
| 170 |
+
scales: [n_heads] tensor, each element = √γ_h.
|
| 171 |
+
info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
|
| 172 |
+
|
| 173 |
+
MLA sub-region scaling per Algorithm 1 (MuonClip):
|
| 174 |
+
wq_b: q_nope rows → √γ, q_pe rows → γ
|
| 175 |
+
wkv_b: k_nope rows → √γ, v rows → unchanged
|
| 176 |
+
"""
|
| 177 |
+
W = p.data if isinstance(p, torch.nn.Parameter) else p
|
| 178 |
+
|
| 179 |
+
if not info.is_mla:
|
| 180 |
+
# MHA/GQA: uniform √γ applied to all rows in each head
|
| 181 |
+
W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
|
| 182 |
+
return
|
| 183 |
+
|
| 184 |
+
# MLA: vectorized sub-region scaling within each head
|
| 185 |
+
if info.kind == 'wq_b':
|
| 186 |
+
qk_nope = info.qk_nope_head_dim
|
| 187 |
+
qk_head_dim = qk_nope + info.qk_rope_head_dim
|
| 188 |
+
W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim]
|
| 189 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ
|
| 190 |
+
W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1,
|
| 191 |
+
1)) # q_pe → γ
|
| 192 |
+
|
| 193 |
+
elif info.kind == 'wkv_b':
|
| 194 |
+
qk_nope = info.qk_nope_head_dim
|
| 195 |
+
kv_stride = qk_nope + info.v_head_dim
|
| 196 |
+
W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim]
|
| 197 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ
|
| 198 |
+
# v rows: not touched (k_R shared rotary unchanged)
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8d53b78_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8d53b78_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8d53b78_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1866400
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:055206c495ecade2fe4b5427db34f0a48152174e79808cbe1ce7d7ca86d32396
|
| 3 |
size 1866400
|
|
@@ -93,10 +93,7 @@ class CPUOffloadPool:
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
-
cpu_flat = torch.empty(off,
|
| 97 |
-
dtype=dtype,
|
| 98 |
-
device="cpu",
|
| 99 |
-
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
@@ -140,8 +137,7 @@ class CPUOffloadPool:
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
-
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
-
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
|
@@ -159,8 +155,10 @@ class CPUOffloadPool:
|
|
| 159 |
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
-
logger.info(
|
| 163 |
-
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# ------------------------------------------------------------------
|
| 166 |
def reload(self):
|
|
@@ -198,12 +196,11 @@ class CPUOffloadPool:
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
-
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
-
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
| 206 |
if not self._logged:
|
| 207 |
-
logger.info(
|
| 208 |
-
|
| 209 |
-
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
|
|
|
|
|
|
|
|
|
|
| 97 |
self._groups[dtype] = {
|
| 98 |
"indices": indices,
|
| 99 |
"offsets": offsets,
|
|
|
|
| 137 |
for i, mgd_idx in enumerate(indices):
|
| 138 |
local = self._local(self._managed[mgd_idx])
|
| 139 |
off, n = offsets[i]
|
| 140 |
+
cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
|
|
|
|
| 141 |
|
| 142 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 143 |
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
if not self._logged:
|
| 158 |
+
logger.info(
|
| 159 |
+
"[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
| 160 |
+
offloaded_bytes / (1024**2),
|
| 161 |
+
)
|
| 162 |
|
| 163 |
# ------------------------------------------------------------------
|
| 164 |
def reload(self):
|
|
|
|
| 196 |
for i, mgd_idx in enumerate(indices):
|
| 197 |
local = self._local(self._managed[mgd_idx])
|
| 198 |
off, n = offsets[i]
|
| 199 |
+
local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
|
|
|
|
| 200 |
|
| 201 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 202 |
|
| 203 |
if not self._logged:
|
| 204 |
+
logger.info(
|
| 205 |
+
"[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
|
| 206 |
+
)
|
|
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1054 |
return super().state_dict()
|
| 1055 |
|
| 1056 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
if self.cpu_offload:
|
| 1058 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError(
|
| 1054 |
+
"Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
|
| 1055 |
+
)
|
| 1056 |
return super().state_dict()
|
| 1057 |
|
| 1058 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1059 |
if self.cpu_offload:
|
| 1060 |
+
raise RuntimeError(
|
| 1061 |
+
"Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
|
| 1062 |
+
)
|
| 1063 |
super().load_state_dict(state_dict)
|
| 1064 |
|
| 1065 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
-
LHS = np.array(
|
| 36 |
-
[
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 42 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 43 |
-
raise ValueError(
|
| 44 |
-
|
|
|
|
| 45 |
q, r = np.sqrt(
|
| 46 |
-
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 47 |
-
|
| 48 |
if not np.all(np.isfinite([q, r])):
|
| 49 |
-
raise ValueError(
|
| 50 |
-
f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
| 51 |
if abs(old_E - E) <= 1e-15:
|
| 52 |
break
|
| 53 |
else:
|
| 54 |
raise RuntimeError(
|
| 55 |
-
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
|
|
|
| 56 |
return float(a), float(b), float(c)
|
| 57 |
|
| 58 |
|
|
@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
|
| 111 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 112 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 113 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 114 |
-
_coeffs_list = _optimal_composition(
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
cushion=0.02)
|
| 118 |
|
| 119 |
|
| 120 |
# This code is adapted from:
|
|
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 148 |
|
| 149 |
X = X / (X.norm() + 1e-7)
|
| 150 |
hs = _coeffs_list[:steps] + list(
|
| 151 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 152 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 153 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 154 |
# Perform the NS iterations
|
|
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
|
|
| 183 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 184 |
|
| 185 |
hs = _coeffs_list[:steps] + list(
|
| 186 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 187 |
for a, b, c in hs:
|
| 188 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 189 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
+
LHS = np.array(
|
| 36 |
+
[
|
| 37 |
+
[l, l**3, l**5, 1],
|
| 38 |
+
[q, q**3, q**5, -1],
|
| 39 |
+
[r, r**3, r**5, 1],
|
| 40 |
+
[u, u**3, u**5, -1],
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 44 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
|
| 47 |
+
)
|
| 48 |
q, r = np.sqrt(
|
| 49 |
+
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
|
| 50 |
+
)
|
| 51 |
if not np.all(np.isfinite([q, r])):
|
| 52 |
+
raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
|
|
|
| 53 |
if abs(old_E - E) <= 1e-15:
|
| 54 |
break
|
| 55 |
else:
|
| 56 |
raise RuntimeError(
|
| 57 |
+
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
| 58 |
+
)
|
| 59 |
return float(a), float(b), float(c)
|
| 60 |
|
| 61 |
|
|
|
|
| 114 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 115 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 116 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 117 |
+
_coeffs_list = _optimal_composition(
|
| 118 |
+
l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
|
| 119 |
+
)
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
# This code is adapted from:
|
|
|
|
| 150 |
|
| 151 |
X = X / (X.norm() + 1e-7)
|
| 152 |
hs = _coeffs_list[:steps] + list(
|
| 153 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 154 |
+
)
|
| 155 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 156 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 157 |
# Perform the NS iterations
|
|
|
|
| 186 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 187 |
|
| 188 |
hs = _coeffs_list[:steps] + list(
|
| 189 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 190 |
+
)
|
| 191 |
for a, b, c in hs:
|
| 192 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 193 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
Returns:
|
| 19 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 23 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 24 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 25 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
|
|
|
|
|
|
| 26 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 27 |
"""
|
| 28 |
parts = normalize_fqn(name).split('.')
|
|
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 37 |
layer_idx = int(part)
|
| 38 |
break
|
| 39 |
|
| 40 |
-
if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
|
| 41 |
return kind, layer_idx
|
| 42 |
|
| 43 |
return None, -1
|
|
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 46 |
@dataclass
|
| 47 |
class QKClipInfo:
|
| 48 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 49 |
-
kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
|
| 50 |
indices: list[int] # which heads to consider for clipping
|
| 51 |
-
head_dim: int # from config
|
| 52 |
threshold: float # from config
|
| 53 |
logit: torch.Tensor | None
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 57 |
"""Extract QK clipping info for a named parameter.
|
| 58 |
|
| 59 |
Args:
|
| 60 |
clip_config: QK clipping configuration dict (or None).
|
|
|
|
|
|
|
| 61 |
n: Parameter name string.
|
| 62 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 63 |
|
|
@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
|
|
| 70 |
head_dim = clip_config.get('head_dim')
|
| 71 |
threshold = clip_config.get('threshold')
|
| 72 |
kind, layer_idx = parse_qk_layer(n)
|
|
|
|
| 73 |
|
| 74 |
logit, indices = None, []
|
| 75 |
if qk_logits is not None and kind is not None:
|
| 76 |
logit = qk_logits[layer_idx]
|
| 77 |
-
indices_key = 'q_indices' if 'q' in kind else 'k_indices'
|
| 78 |
-
indices = clip_config.get(indices_key, []) or []
|
| 79 |
-
|
| 80 |
if isinstance(logit, DTensor):
|
| 81 |
# In TP settings, qk_logits may be DTensor
|
| 82 |
# We convert it to full tensor here for simplicity
|
| 83 |
logit = logit.full_tensor()
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def compute_scales(p, qk_clip_state):
|
| 95 |
"""Compute per-head scaling factors for QK clipping.
|
| 96 |
|
| 97 |
-
Returns scales tensor if any head exceeds threshold, else None.
|
|
|
|
| 98 |
"""
|
| 99 |
kind = qk_clip_state.kind
|
| 100 |
indices = qk_clip_state.indices
|
|
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
|
|
| 118 |
if not head_scales:
|
| 119 |
return None
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 123 |
for head_idx, scale in head_scales.items():
|
| 124 |
scales_full[head_idx] = scale
|
| 125 |
return scales_full
|
| 126 |
|
| 127 |
|
| 128 |
-
def qk_clip(p, scales,
|
| 129 |
-
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
+
and return (kind, layer_index).
|
| 17 |
+
|
| 18 |
+
Supported kinds:
|
| 19 |
+
MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
|
| 20 |
+
MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
|
|
| 27 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 28 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 29 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
| 30 |
+
'model.1.attn.wq_b.weight' -> ('wq_b', 1)
|
| 31 |
+
'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
|
| 32 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 33 |
"""
|
| 34 |
parts = normalize_fqn(name).split('.')
|
|
|
|
| 43 |
layer_idx = int(part)
|
| 44 |
break
|
| 45 |
|
| 46 |
+
if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
|
| 47 |
return kind, layer_idx
|
| 48 |
|
| 49 |
return None, -1
|
|
|
|
| 52 |
@dataclass
|
| 53 |
class QKClipInfo:
|
| 54 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 55 |
+
kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
|
| 56 |
indices: list[int] # which heads to consider for clipping
|
| 57 |
+
head_dim: int # from config (qk_head_dim for MLA wq_b)
|
| 58 |
threshold: float # from config
|
| 59 |
logit: torch.Tensor | None
|
| 60 |
|
| 61 |
+
# MLA-specific fields
|
| 62 |
+
is_mla: bool = False
|
| 63 |
+
qk_nope_head_dim: int = 0
|
| 64 |
+
qk_rope_head_dim: int = 0
|
| 65 |
+
v_head_dim: int = 0
|
| 66 |
+
|
| 67 |
|
| 68 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 69 |
"""Extract QK clipping info for a named parameter.
|
| 70 |
|
| 71 |
Args:
|
| 72 |
clip_config: QK clipping configuration dict (or None).
|
| 73 |
+
MHA/GQA keys: head_dim, threshold, q_indices, k_indices
|
| 74 |
+
MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
|
| 75 |
n: Parameter name string.
|
| 76 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 77 |
|
|
|
|
| 84 |
head_dim = clip_config.get('head_dim')
|
| 85 |
threshold = clip_config.get('threshold')
|
| 86 |
kind, layer_idx = parse_qk_layer(n)
|
| 87 |
+
is_mla = clip_config.get('is_mla', False)
|
| 88 |
|
| 89 |
logit, indices = None, []
|
| 90 |
if qk_logits is not None and kind is not None:
|
| 91 |
logit = qk_logits[layer_idx]
|
|
|
|
|
|
|
|
|
|
| 92 |
if isinstance(logit, DTensor):
|
| 93 |
# In TP settings, qk_logits may be DTensor
|
| 94 |
# We convert it to full tensor here for simplicity
|
| 95 |
logit = logit.full_tensor()
|
| 96 |
|
| 97 |
+
if kind in ('wq_b', 'wq', 'q_proj'):
|
| 98 |
+
indices = clip_config.get('q_indices', []) or []
|
| 99 |
+
elif kind in ('wkv_b', 'wk', 'k_proj'):
|
| 100 |
+
indices = clip_config.get('k_indices', []) or []
|
| 101 |
+
|
| 102 |
+
if is_mla:
|
| 103 |
+
return QKClipInfo(
|
| 104 |
+
kind=kind,
|
| 105 |
+
indices=indices,
|
| 106 |
+
head_dim=head_dim,
|
| 107 |
+
threshold=threshold,
|
| 108 |
+
logit=logit,
|
| 109 |
+
is_mla=True,
|
| 110 |
+
qk_nope_head_dim=clip_config['qk_nope_head_dim'],
|
| 111 |
+
qk_rope_head_dim=clip_config['qk_rope_head_dim'],
|
| 112 |
+
v_head_dim=clip_config['v_head_dim'],
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
return QKClipInfo(
|
| 116 |
+
kind=kind,
|
| 117 |
+
indices=indices,
|
| 118 |
+
head_dim=head_dim,
|
| 119 |
+
threshold=threshold,
|
| 120 |
+
logit=logit,
|
| 121 |
+
)
|
| 122 |
|
| 123 |
|
| 124 |
def compute_scales(p, qk_clip_state):
|
| 125 |
"""Compute per-head scaling factors for QK clipping.
|
| 126 |
|
| 127 |
+
Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
|
| 128 |
+
For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
|
| 129 |
"""
|
| 130 |
kind = qk_clip_state.kind
|
| 131 |
indices = qk_clip_state.indices
|
|
|
|
| 149 |
if not head_scales:
|
| 150 |
return None
|
| 151 |
|
| 152 |
+
# For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
|
| 153 |
+
if qk_clip_state.is_mla and kind == 'wkv_b':
|
| 154 |
+
effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
|
| 155 |
+
else:
|
| 156 |
+
effective_head_dim = head_dim
|
| 157 |
+
|
| 158 |
+
H_global = p.shape[0] // effective_head_dim
|
| 159 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 160 |
for head_idx, scale in head_scales.items():
|
| 161 |
scales_full[head_idx] = scale
|
| 162 |
return scales_full
|
| 163 |
|
| 164 |
|
| 165 |
+
def qk_clip(p, scales, info):
|
| 166 |
+
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
p: Parameter (nn.Parameter or raw tensor).
|
| 170 |
+
scales: [n_heads] tensor, each element = √γ_h.
|
| 171 |
+
info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
|
| 172 |
+
|
| 173 |
+
MLA sub-region scaling per Algorithm 1 (MuonClip):
|
| 174 |
+
wq_b: q_nope rows → √γ, q_pe rows → γ
|
| 175 |
+
wkv_b: k_nope rows → √γ, v rows → unchanged
|
| 176 |
+
"""
|
| 177 |
+
W = p.data if isinstance(p, torch.nn.Parameter) else p
|
| 178 |
+
|
| 179 |
+
if not info.is_mla:
|
| 180 |
+
# MHA/GQA: uniform √γ applied to all rows in each head
|
| 181 |
+
W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
|
| 182 |
+
return
|
| 183 |
+
|
| 184 |
+
# MLA: vectorized sub-region scaling within each head
|
| 185 |
+
if info.kind == 'wq_b':
|
| 186 |
+
qk_nope = info.qk_nope_head_dim
|
| 187 |
+
qk_head_dim = qk_nope + info.qk_rope_head_dim
|
| 188 |
+
W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim]
|
| 189 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ
|
| 190 |
+
W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1,
|
| 191 |
+
1)) # q_pe → γ
|
| 192 |
+
|
| 193 |
+
elif info.kind == 'wkv_b':
|
| 194 |
+
qk_nope = info.qk_nope_head_dim
|
| 195 |
+
kv_stride = qk_nope + info.v_head_dim
|
| 196 |
+
W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim]
|
| 197 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ
|
| 198 |
+
# v rows: not touched (k_R shared rotary unchanged)
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8d53b78_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8d53b78_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8d53b78_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1866112
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:315ff09ffa88ec806cb8abe49edb2ca6951e9ac34be3d3e10f159093f9576ee0
|
| 3 |
size 1866112
|
|
@@ -93,10 +93,7 @@ class CPUOffloadPool:
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
-
cpu_flat = torch.empty(off,
|
| 97 |
-
dtype=dtype,
|
| 98 |
-
device="cpu",
|
| 99 |
-
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
@@ -140,8 +137,7 @@ class CPUOffloadPool:
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
-
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
-
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
|
@@ -159,8 +155,10 @@ class CPUOffloadPool:
|
|
| 159 |
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
-
logger.info(
|
| 163 |
-
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# ------------------------------------------------------------------
|
| 166 |
def reload(self):
|
|
@@ -198,12 +196,11 @@ class CPUOffloadPool:
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
-
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
-
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
| 206 |
if not self._logged:
|
| 207 |
-
logger.info(
|
| 208 |
-
|
| 209 |
-
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
|
|
|
|
|
|
|
|
|
|
| 97 |
self._groups[dtype] = {
|
| 98 |
"indices": indices,
|
| 99 |
"offsets": offsets,
|
|
|
|
| 137 |
for i, mgd_idx in enumerate(indices):
|
| 138 |
local = self._local(self._managed[mgd_idx])
|
| 139 |
off, n = offsets[i]
|
| 140 |
+
cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
|
|
|
|
| 141 |
|
| 142 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 143 |
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
if not self._logged:
|
| 158 |
+
logger.info(
|
| 159 |
+
"[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
| 160 |
+
offloaded_bytes / (1024**2),
|
| 161 |
+
)
|
| 162 |
|
| 163 |
# ------------------------------------------------------------------
|
| 164 |
def reload(self):
|
|
|
|
| 196 |
for i, mgd_idx in enumerate(indices):
|
| 197 |
local = self._local(self._managed[mgd_idx])
|
| 198 |
off, n = offsets[i]
|
| 199 |
+
local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
|
|
|
|
| 200 |
|
| 201 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 202 |
|
| 203 |
if not self._logged:
|
| 204 |
+
logger.info(
|
| 205 |
+
"[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
|
| 206 |
+
)
|
|
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1054 |
return super().state_dict()
|
| 1055 |
|
| 1056 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
if self.cpu_offload:
|
| 1058 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError(
|
| 1054 |
+
"Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
|
| 1055 |
+
)
|
| 1056 |
return super().state_dict()
|
| 1057 |
|
| 1058 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1059 |
if self.cpu_offload:
|
| 1060 |
+
raise RuntimeError(
|
| 1061 |
+
"Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
|
| 1062 |
+
)
|
| 1063 |
super().load_state_dict(state_dict)
|
| 1064 |
|
| 1065 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
-
LHS = np.array(
|
| 36 |
-
[
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 42 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 43 |
-
raise ValueError(
|
| 44 |
-
|
|
|
|
| 45 |
q, r = np.sqrt(
|
| 46 |
-
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 47 |
-
|
| 48 |
if not np.all(np.isfinite([q, r])):
|
| 49 |
-
raise ValueError(
|
| 50 |
-
f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
| 51 |
if abs(old_E - E) <= 1e-15:
|
| 52 |
break
|
| 53 |
else:
|
| 54 |
raise RuntimeError(
|
| 55 |
-
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
|
|
|
| 56 |
return float(a), float(b), float(c)
|
| 57 |
|
| 58 |
|
|
@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
|
| 111 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 112 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 113 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 114 |
-
_coeffs_list = _optimal_composition(
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
cushion=0.02)
|
| 118 |
|
| 119 |
|
| 120 |
# This code is adapted from:
|
|
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 148 |
|
| 149 |
X = X / (X.norm() + 1e-7)
|
| 150 |
hs = _coeffs_list[:steps] + list(
|
| 151 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 152 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 153 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 154 |
# Perform the NS iterations
|
|
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
|
|
| 183 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 184 |
|
| 185 |
hs = _coeffs_list[:steps] + list(
|
| 186 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 187 |
for a, b, c in hs:
|
| 188 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 189 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
+
LHS = np.array(
|
| 36 |
+
[
|
| 37 |
+
[l, l**3, l**5, 1],
|
| 38 |
+
[q, q**3, q**5, -1],
|
| 39 |
+
[r, r**3, r**5, 1],
|
| 40 |
+
[u, u**3, u**5, -1],
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 44 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
|
| 47 |
+
)
|
| 48 |
q, r = np.sqrt(
|
| 49 |
+
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
|
| 50 |
+
)
|
| 51 |
if not np.all(np.isfinite([q, r])):
|
| 52 |
+
raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
|
|
|
| 53 |
if abs(old_E - E) <= 1e-15:
|
| 54 |
break
|
| 55 |
else:
|
| 56 |
raise RuntimeError(
|
| 57 |
+
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
| 58 |
+
)
|
| 59 |
return float(a), float(b), float(c)
|
| 60 |
|
| 61 |
|
|
|
|
| 114 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 115 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 116 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 117 |
+
_coeffs_list = _optimal_composition(
|
| 118 |
+
l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
|
| 119 |
+
)
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
# This code is adapted from:
|
|
|
|
| 150 |
|
| 151 |
X = X / (X.norm() + 1e-7)
|
| 152 |
hs = _coeffs_list[:steps] + list(
|
| 153 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 154 |
+
)
|
| 155 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 156 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 157 |
# Perform the NS iterations
|
|
|
|
| 186 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 187 |
|
| 188 |
hs = _coeffs_list[:steps] + list(
|
| 189 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 190 |
+
)
|
| 191 |
for a, b, c in hs:
|
| 192 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 193 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
Returns:
|
| 19 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 23 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 24 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 25 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
|
|
|
|
|
|
| 26 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 27 |
"""
|
| 28 |
parts = normalize_fqn(name).split('.')
|
|
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 37 |
layer_idx = int(part)
|
| 38 |
break
|
| 39 |
|
| 40 |
-
if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
|
| 41 |
return kind, layer_idx
|
| 42 |
|
| 43 |
return None, -1
|
|
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 46 |
@dataclass
|
| 47 |
class QKClipInfo:
|
| 48 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 49 |
-
kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
|
| 50 |
indices: list[int] # which heads to consider for clipping
|
| 51 |
-
head_dim: int # from config
|
| 52 |
threshold: float # from config
|
| 53 |
logit: torch.Tensor | None
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 57 |
"""Extract QK clipping info for a named parameter.
|
| 58 |
|
| 59 |
Args:
|
| 60 |
clip_config: QK clipping configuration dict (or None).
|
|
|
|
|
|
|
| 61 |
n: Parameter name string.
|
| 62 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 63 |
|
|
@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
|
|
| 70 |
head_dim = clip_config.get('head_dim')
|
| 71 |
threshold = clip_config.get('threshold')
|
| 72 |
kind, layer_idx = parse_qk_layer(n)
|
|
|
|
| 73 |
|
| 74 |
logit, indices = None, []
|
| 75 |
if qk_logits is not None and kind is not None:
|
| 76 |
logit = qk_logits[layer_idx]
|
| 77 |
-
indices_key = 'q_indices' if 'q' in kind else 'k_indices'
|
| 78 |
-
indices = clip_config.get(indices_key, []) or []
|
| 79 |
-
|
| 80 |
if isinstance(logit, DTensor):
|
| 81 |
# In TP settings, qk_logits may be DTensor
|
| 82 |
# We convert it to full tensor here for simplicity
|
| 83 |
logit = logit.full_tensor()
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def compute_scales(p, qk_clip_state):
|
| 95 |
"""Compute per-head scaling factors for QK clipping.
|
| 96 |
|
| 97 |
-
Returns scales tensor if any head exceeds threshold, else None.
|
|
|
|
| 98 |
"""
|
| 99 |
kind = qk_clip_state.kind
|
| 100 |
indices = qk_clip_state.indices
|
|
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
|
|
| 118 |
if not head_scales:
|
| 119 |
return None
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 123 |
for head_idx, scale in head_scales.items():
|
| 124 |
scales_full[head_idx] = scale
|
| 125 |
return scales_full
|
| 126 |
|
| 127 |
|
| 128 |
-
def qk_clip(p, scales,
|
| 129 |
-
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
+
and return (kind, layer_index).
|
| 17 |
+
|
| 18 |
+
Supported kinds:
|
| 19 |
+
MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
|
| 20 |
+
MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
|
|
| 27 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 28 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 29 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
| 30 |
+
'model.1.attn.wq_b.weight' -> ('wq_b', 1)
|
| 31 |
+
'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
|
| 32 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 33 |
"""
|
| 34 |
parts = normalize_fqn(name).split('.')
|
|
|
|
| 43 |
layer_idx = int(part)
|
| 44 |
break
|
| 45 |
|
| 46 |
+
if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
|
| 47 |
return kind, layer_idx
|
| 48 |
|
| 49 |
return None, -1
|
|
|
|
| 52 |
@dataclass
|
| 53 |
class QKClipInfo:
|
| 54 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 55 |
+
kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
|
| 56 |
indices: list[int] # which heads to consider for clipping
|
| 57 |
+
head_dim: int # from config (qk_head_dim for MLA wq_b)
|
| 58 |
threshold: float # from config
|
| 59 |
logit: torch.Tensor | None
|
| 60 |
|
| 61 |
+
# MLA-specific fields
|
| 62 |
+
is_mla: bool = False
|
| 63 |
+
qk_nope_head_dim: int = 0
|
| 64 |
+
qk_rope_head_dim: int = 0
|
| 65 |
+
v_head_dim: int = 0
|
| 66 |
+
|
| 67 |
|
| 68 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 69 |
"""Extract QK clipping info for a named parameter.
|
| 70 |
|
| 71 |
Args:
|
| 72 |
clip_config: QK clipping configuration dict (or None).
|
| 73 |
+
MHA/GQA keys: head_dim, threshold, q_indices, k_indices
|
| 74 |
+
MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
|
| 75 |
n: Parameter name string.
|
| 76 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 77 |
|
|
|
|
| 84 |
head_dim = clip_config.get('head_dim')
|
| 85 |
threshold = clip_config.get('threshold')
|
| 86 |
kind, layer_idx = parse_qk_layer(n)
|
| 87 |
+
is_mla = clip_config.get('is_mla', False)
|
| 88 |
|
| 89 |
logit, indices = None, []
|
| 90 |
if qk_logits is not None and kind is not None:
|
| 91 |
logit = qk_logits[layer_idx]
|
|
|
|
|
|
|
|
|
|
| 92 |
if isinstance(logit, DTensor):
|
| 93 |
# In TP settings, qk_logits may be DTensor
|
| 94 |
# We convert it to full tensor here for simplicity
|
| 95 |
logit = logit.full_tensor()
|
| 96 |
|
| 97 |
+
if kind in ('wq_b', 'wq', 'q_proj'):
|
| 98 |
+
indices = clip_config.get('q_indices', []) or []
|
| 99 |
+
elif kind in ('wkv_b', 'wk', 'k_proj'):
|
| 100 |
+
indices = clip_config.get('k_indices', []) or []
|
| 101 |
+
|
| 102 |
+
if is_mla:
|
| 103 |
+
return QKClipInfo(
|
| 104 |
+
kind=kind,
|
| 105 |
+
indices=indices,
|
| 106 |
+
head_dim=head_dim,
|
| 107 |
+
threshold=threshold,
|
| 108 |
+
logit=logit,
|
| 109 |
+
is_mla=True,
|
| 110 |
+
qk_nope_head_dim=clip_config['qk_nope_head_dim'],
|
| 111 |
+
qk_rope_head_dim=clip_config['qk_rope_head_dim'],
|
| 112 |
+
v_head_dim=clip_config['v_head_dim'],
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
return QKClipInfo(
|
| 116 |
+
kind=kind,
|
| 117 |
+
indices=indices,
|
| 118 |
+
head_dim=head_dim,
|
| 119 |
+
threshold=threshold,
|
| 120 |
+
logit=logit,
|
| 121 |
+
)
|
| 122 |
|
| 123 |
|
| 124 |
def compute_scales(p, qk_clip_state):
|
| 125 |
"""Compute per-head scaling factors for QK clipping.
|
| 126 |
|
| 127 |
+
Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
|
| 128 |
+
For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
|
| 129 |
"""
|
| 130 |
kind = qk_clip_state.kind
|
| 131 |
indices = qk_clip_state.indices
|
|
|
|
| 149 |
if not head_scales:
|
| 150 |
return None
|
| 151 |
|
| 152 |
+
# For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
|
| 153 |
+
if qk_clip_state.is_mla and kind == 'wkv_b':
|
| 154 |
+
effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
|
| 155 |
+
else:
|
| 156 |
+
effective_head_dim = head_dim
|
| 157 |
+
|
| 158 |
+
H_global = p.shape[0] // effective_head_dim
|
| 159 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 160 |
for head_idx, scale in head_scales.items():
|
| 161 |
scales_full[head_idx] = scale
|
| 162 |
return scales_full
|
| 163 |
|
| 164 |
|
| 165 |
+
def qk_clip(p, scales, info):
|
| 166 |
+
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
p: Parameter (nn.Parameter or raw tensor).
|
| 170 |
+
scales: [n_heads] tensor, each element = √γ_h.
|
| 171 |
+
info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
|
| 172 |
+
|
| 173 |
+
MLA sub-region scaling per Algorithm 1 (MuonClip):
|
| 174 |
+
wq_b: q_nope rows → √γ, q_pe rows → γ
|
| 175 |
+
wkv_b: k_nope rows → √γ, v rows → unchanged
|
| 176 |
+
"""
|
| 177 |
+
W = p.data if isinstance(p, torch.nn.Parameter) else p
|
| 178 |
+
|
| 179 |
+
if not info.is_mla:
|
| 180 |
+
# MHA/GQA: uniform √γ applied to all rows in each head
|
| 181 |
+
W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
|
| 182 |
+
return
|
| 183 |
+
|
| 184 |
+
# MLA: vectorized sub-region scaling within each head
|
| 185 |
+
if info.kind == 'wq_b':
|
| 186 |
+
qk_nope = info.qk_nope_head_dim
|
| 187 |
+
qk_head_dim = qk_nope + info.qk_rope_head_dim
|
| 188 |
+
W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim]
|
| 189 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ
|
| 190 |
+
W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1,
|
| 191 |
+
1)) # q_pe → γ
|
| 192 |
+
|
| 193 |
+
elif info.kind == 'wkv_b':
|
| 194 |
+
qk_nope = info.qk_nope_head_dim
|
| 195 |
+
kv_stride = qk_nope + info.v_head_dim
|
| 196 |
+
W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim]
|
| 197 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ
|
| 198 |
+
# v rows: not touched (k_R shared rotary unchanged)
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8d53b78_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8d53b78_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8d53b78_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1936664
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a9a7c1beffbad405ef7d6f46f44cf9c6671d119e04a340b54c8f4c8f9d699caf
|
| 3 |
size 1936664
|
|
@@ -93,10 +93,7 @@ class CPUOffloadPool:
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
-
cpu_flat = torch.empty(off,
|
| 97 |
-
dtype=dtype,
|
| 98 |
-
device="cpu",
|
| 99 |
-
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
@@ -140,8 +137,7 @@ class CPUOffloadPool:
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
-
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
-
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
|
@@ -159,8 +155,10 @@ class CPUOffloadPool:
|
|
| 159 |
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
-
logger.info(
|
| 163 |
-
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# ------------------------------------------------------------------
|
| 166 |
def reload(self):
|
|
@@ -198,12 +196,11 @@ class CPUOffloadPool:
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
-
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
-
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
| 206 |
if not self._logged:
|
| 207 |
-
logger.info(
|
| 208 |
-
|
| 209 |
-
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
|
|
|
|
|
|
|
|
|
|
| 97 |
self._groups[dtype] = {
|
| 98 |
"indices": indices,
|
| 99 |
"offsets": offsets,
|
|
|
|
| 137 |
for i, mgd_idx in enumerate(indices):
|
| 138 |
local = self._local(self._managed[mgd_idx])
|
| 139 |
off, n = offsets[i]
|
| 140 |
+
cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
|
|
|
|
| 141 |
|
| 142 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 143 |
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
if not self._logged:
|
| 158 |
+
logger.info(
|
| 159 |
+
"[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
| 160 |
+
offloaded_bytes / (1024**2),
|
| 161 |
+
)
|
| 162 |
|
| 163 |
# ------------------------------------------------------------------
|
| 164 |
def reload(self):
|
|
|
|
| 196 |
for i, mgd_idx in enumerate(indices):
|
| 197 |
local = self._local(self._managed[mgd_idx])
|
| 198 |
off, n = offsets[i]
|
| 199 |
+
local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
|
|
|
|
| 200 |
|
| 201 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 202 |
|
| 203 |
if not self._logged:
|
| 204 |
+
logger.info(
|
| 205 |
+
"[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
|
| 206 |
+
)
|
|
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1054 |
return super().state_dict()
|
| 1055 |
|
| 1056 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
if self.cpu_offload:
|
| 1058 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError(
|
| 1054 |
+
"Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
|
| 1055 |
+
)
|
| 1056 |
return super().state_dict()
|
| 1057 |
|
| 1058 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1059 |
if self.cpu_offload:
|
| 1060 |
+
raise RuntimeError(
|
| 1061 |
+
"Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
|
| 1062 |
+
)
|
| 1063 |
super().load_state_dict(state_dict)
|
| 1064 |
|
| 1065 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
-
LHS = np.array(
|
| 36 |
-
[
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 42 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 43 |
-
raise ValueError(
|
| 44 |
-
|
|
|
|
| 45 |
q, r = np.sqrt(
|
| 46 |
-
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 47 |
-
|
| 48 |
if not np.all(np.isfinite([q, r])):
|
| 49 |
-
raise ValueError(
|
| 50 |
-
f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
| 51 |
if abs(old_E - E) <= 1e-15:
|
| 52 |
break
|
| 53 |
else:
|
| 54 |
raise RuntimeError(
|
| 55 |
-
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
|
|
|
| 56 |
return float(a), float(b), float(c)
|
| 57 |
|
| 58 |
|
|
@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
|
| 111 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 112 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 113 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 114 |
-
_coeffs_list = _optimal_composition(
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
cushion=0.02)
|
| 118 |
|
| 119 |
|
| 120 |
# This code is adapted from:
|
|
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 148 |
|
| 149 |
X = X / (X.norm() + 1e-7)
|
| 150 |
hs = _coeffs_list[:steps] + list(
|
| 151 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 152 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 153 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 154 |
# Perform the NS iterations
|
|
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
|
|
| 183 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 184 |
|
| 185 |
hs = _coeffs_list[:steps] + list(
|
| 186 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 187 |
for a, b, c in hs:
|
| 188 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 189 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
+
LHS = np.array(
|
| 36 |
+
[
|
| 37 |
+
[l, l**3, l**5, 1],
|
| 38 |
+
[q, q**3, q**5, -1],
|
| 39 |
+
[r, r**3, r**5, 1],
|
| 40 |
+
[u, u**3, u**5, -1],
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 44 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
|
| 47 |
+
)
|
| 48 |
q, r = np.sqrt(
|
| 49 |
+
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
|
| 50 |
+
)
|
| 51 |
if not np.all(np.isfinite([q, r])):
|
| 52 |
+
raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
|
|
|
| 53 |
if abs(old_E - E) <= 1e-15:
|
| 54 |
break
|
| 55 |
else:
|
| 56 |
raise RuntimeError(
|
| 57 |
+
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
| 58 |
+
)
|
| 59 |
return float(a), float(b), float(c)
|
| 60 |
|
| 61 |
|
|
|
|
| 114 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 115 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 116 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 117 |
+
_coeffs_list = _optimal_composition(
|
| 118 |
+
l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
|
| 119 |
+
)
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
# This code is adapted from:
|
|
|
|
| 150 |
|
| 151 |
X = X / (X.norm() + 1e-7)
|
| 152 |
hs = _coeffs_list[:steps] + list(
|
| 153 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 154 |
+
)
|
| 155 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 156 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 157 |
# Perform the NS iterations
|
|
|
|
| 186 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 187 |
|
| 188 |
hs = _coeffs_list[:steps] + list(
|
| 189 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 190 |
+
)
|
| 191 |
for a, b, c in hs:
|
| 192 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 193 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
Returns:
|
| 19 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 23 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 24 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 25 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
|
|
|
|
|
|
| 26 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 27 |
"""
|
| 28 |
parts = normalize_fqn(name).split('.')
|
|
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 37 |
layer_idx = int(part)
|
| 38 |
break
|
| 39 |
|
| 40 |
-
if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
|
| 41 |
return kind, layer_idx
|
| 42 |
|
| 43 |
return None, -1
|
|
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 46 |
@dataclass
|
| 47 |
class QKClipInfo:
|
| 48 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 49 |
-
kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
|
| 50 |
indices: list[int] # which heads to consider for clipping
|
| 51 |
-
head_dim: int # from config
|
| 52 |
threshold: float # from config
|
| 53 |
logit: torch.Tensor | None
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 57 |
"""Extract QK clipping info for a named parameter.
|
| 58 |
|
| 59 |
Args:
|
| 60 |
clip_config: QK clipping configuration dict (or None).
|
|
|
|
|
|
|
| 61 |
n: Parameter name string.
|
| 62 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 63 |
|
|
@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
|
|
| 70 |
head_dim = clip_config.get('head_dim')
|
| 71 |
threshold = clip_config.get('threshold')
|
| 72 |
kind, layer_idx = parse_qk_layer(n)
|
|
|
|
| 73 |
|
| 74 |
logit, indices = None, []
|
| 75 |
if qk_logits is not None and kind is not None:
|
| 76 |
logit = qk_logits[layer_idx]
|
| 77 |
-
indices_key = 'q_indices' if 'q' in kind else 'k_indices'
|
| 78 |
-
indices = clip_config.get(indices_key, []) or []
|
| 79 |
-
|
| 80 |
if isinstance(logit, DTensor):
|
| 81 |
# In TP settings, qk_logits may be DTensor
|
| 82 |
# We convert it to full tensor here for simplicity
|
| 83 |
logit = logit.full_tensor()
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def compute_scales(p, qk_clip_state):
|
| 95 |
"""Compute per-head scaling factors for QK clipping.
|
| 96 |
|
| 97 |
-
Returns scales tensor if any head exceeds threshold, else None.
|
|
|
|
| 98 |
"""
|
| 99 |
kind = qk_clip_state.kind
|
| 100 |
indices = qk_clip_state.indices
|
|
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
|
|
| 118 |
if not head_scales:
|
| 119 |
return None
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 123 |
for head_idx, scale in head_scales.items():
|
| 124 |
scales_full[head_idx] = scale
|
| 125 |
return scales_full
|
| 126 |
|
| 127 |
|
| 128 |
-
def qk_clip(p, scales,
|
| 129 |
-
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
+
and return (kind, layer_index).
|
| 17 |
+
|
| 18 |
+
Supported kinds:
|
| 19 |
+
MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
|
| 20 |
+
MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
|
|
| 27 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 28 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 29 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
| 30 |
+
'model.1.attn.wq_b.weight' -> ('wq_b', 1)
|
| 31 |
+
'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
|
| 32 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 33 |
"""
|
| 34 |
parts = normalize_fqn(name).split('.')
|
|
|
|
| 43 |
layer_idx = int(part)
|
| 44 |
break
|
| 45 |
|
| 46 |
+
if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
|
| 47 |
return kind, layer_idx
|
| 48 |
|
| 49 |
return None, -1
|
|
|
|
| 52 |
@dataclass
|
| 53 |
class QKClipInfo:
|
| 54 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 55 |
+
kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
|
| 56 |
indices: list[int] # which heads to consider for clipping
|
| 57 |
+
head_dim: int # from config (qk_head_dim for MLA wq_b)
|
| 58 |
threshold: float # from config
|
| 59 |
logit: torch.Tensor | None
|
| 60 |
|
| 61 |
+
# MLA-specific fields
|
| 62 |
+
is_mla: bool = False
|
| 63 |
+
qk_nope_head_dim: int = 0
|
| 64 |
+
qk_rope_head_dim: int = 0
|
| 65 |
+
v_head_dim: int = 0
|
| 66 |
+
|
| 67 |
|
| 68 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 69 |
"""Extract QK clipping info for a named parameter.
|
| 70 |
|
| 71 |
Args:
|
| 72 |
clip_config: QK clipping configuration dict (or None).
|
| 73 |
+
MHA/GQA keys: head_dim, threshold, q_indices, k_indices
|
| 74 |
+
MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
|
| 75 |
n: Parameter name string.
|
| 76 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 77 |
|
|
|
|
| 84 |
head_dim = clip_config.get('head_dim')
|
| 85 |
threshold = clip_config.get('threshold')
|
| 86 |
kind, layer_idx = parse_qk_layer(n)
|
| 87 |
+
is_mla = clip_config.get('is_mla', False)
|
| 88 |
|
| 89 |
logit, indices = None, []
|
| 90 |
if qk_logits is not None and kind is not None:
|
| 91 |
logit = qk_logits[layer_idx]
|
|
|
|
|
|
|
|
|
|
| 92 |
if isinstance(logit, DTensor):
|
| 93 |
# In TP settings, qk_logits may be DTensor
|
| 94 |
# We convert it to full tensor here for simplicity
|
| 95 |
logit = logit.full_tensor()
|
| 96 |
|
| 97 |
+
if kind in ('wq_b', 'wq', 'q_proj'):
|
| 98 |
+
indices = clip_config.get('q_indices', []) or []
|
| 99 |
+
elif kind in ('wkv_b', 'wk', 'k_proj'):
|
| 100 |
+
indices = clip_config.get('k_indices', []) or []
|
| 101 |
+
|
| 102 |
+
if is_mla:
|
| 103 |
+
return QKClipInfo(
|
| 104 |
+
kind=kind,
|
| 105 |
+
indices=indices,
|
| 106 |
+
head_dim=head_dim,
|
| 107 |
+
threshold=threshold,
|
| 108 |
+
logit=logit,
|
| 109 |
+
is_mla=True,
|
| 110 |
+
qk_nope_head_dim=clip_config['qk_nope_head_dim'],
|
| 111 |
+
qk_rope_head_dim=clip_config['qk_rope_head_dim'],
|
| 112 |
+
v_head_dim=clip_config['v_head_dim'],
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
return QKClipInfo(
|
| 116 |
+
kind=kind,
|
| 117 |
+
indices=indices,
|
| 118 |
+
head_dim=head_dim,
|
| 119 |
+
threshold=threshold,
|
| 120 |
+
logit=logit,
|
| 121 |
+
)
|
| 122 |
|
| 123 |
|
| 124 |
def compute_scales(p, qk_clip_state):
|
| 125 |
"""Compute per-head scaling factors for QK clipping.
|
| 126 |
|
| 127 |
+
Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
|
| 128 |
+
For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
|
| 129 |
"""
|
| 130 |
kind = qk_clip_state.kind
|
| 131 |
indices = qk_clip_state.indices
|
|
|
|
| 149 |
if not head_scales:
|
| 150 |
return None
|
| 151 |
|
| 152 |
+
# For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
|
| 153 |
+
if qk_clip_state.is_mla and kind == 'wkv_b':
|
| 154 |
+
effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
|
| 155 |
+
else:
|
| 156 |
+
effective_head_dim = head_dim
|
| 157 |
+
|
| 158 |
+
H_global = p.shape[0] // effective_head_dim
|
| 159 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 160 |
for head_idx, scale in head_scales.items():
|
| 161 |
scales_full[head_idx] = scale
|
| 162 |
return scales_full
|
| 163 |
|
| 164 |
|
| 165 |
+
def qk_clip(p, scales, info):
|
| 166 |
+
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
p: Parameter (nn.Parameter or raw tensor).
|
| 170 |
+
scales: [n_heads] tensor, each element = √γ_h.
|
| 171 |
+
info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
|
| 172 |
+
|
| 173 |
+
MLA sub-region scaling per Algorithm 1 (MuonClip):
|
| 174 |
+
wq_b: q_nope rows → √γ, q_pe rows → γ
|
| 175 |
+
wkv_b: k_nope rows → √γ, v rows → unchanged
|
| 176 |
+
"""
|
| 177 |
+
W = p.data if isinstance(p, torch.nn.Parameter) else p
|
| 178 |
+
|
| 179 |
+
if not info.is_mla:
|
| 180 |
+
# MHA/GQA: uniform √γ applied to all rows in each head
|
| 181 |
+
W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
|
| 182 |
+
return
|
| 183 |
+
|
| 184 |
+
# MLA: vectorized sub-region scaling within each head
|
| 185 |
+
if info.kind == 'wq_b':
|
| 186 |
+
qk_nope = info.qk_nope_head_dim
|
| 187 |
+
qk_head_dim = qk_nope + info.qk_rope_head_dim
|
| 188 |
+
W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim]
|
| 189 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ
|
| 190 |
+
W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1,
|
| 191 |
+
1)) # q_pe → γ
|
| 192 |
+
|
| 193 |
+
elif info.kind == 'wkv_b':
|
| 194 |
+
qk_nope = info.qk_nope_head_dim
|
| 195 |
+
kv_stride = qk_nope + info.v_head_dim
|
| 196 |
+
W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim]
|
| 197 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ
|
| 198 |
+
# v rows: not touched (k_R shared rotary unchanged)
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8d53b78_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8d53b78_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8d53b78_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1999872
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:090f5a44cdfa4554147159cc36bb7e8ee9dba1ffb1fea4825aa838461fdaddf9
|
| 3 |
size 1999872
|
|
@@ -93,10 +93,7 @@ class CPUOffloadPool:
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
-
cpu_flat = torch.empty(off,
|
| 97 |
-
dtype=dtype,
|
| 98 |
-
device="cpu",
|
| 99 |
-
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
@@ -140,8 +137,7 @@ class CPUOffloadPool:
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
-
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
-
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
|
@@ -159,8 +155,10 @@ class CPUOffloadPool:
|
|
| 159 |
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
-
logger.info(
|
| 163 |
-
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# ------------------------------------------------------------------
|
| 166 |
def reload(self):
|
|
@@ -198,12 +196,11 @@ class CPUOffloadPool:
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
-
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
-
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
| 206 |
if not self._logged:
|
| 207 |
-
logger.info(
|
| 208 |
-
|
| 209 |
-
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
|
|
|
|
|
|
|
|
|
|
| 97 |
self._groups[dtype] = {
|
| 98 |
"indices": indices,
|
| 99 |
"offsets": offsets,
|
|
|
|
| 137 |
for i, mgd_idx in enumerate(indices):
|
| 138 |
local = self._local(self._managed[mgd_idx])
|
| 139 |
off, n = offsets[i]
|
| 140 |
+
cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
|
|
|
|
| 141 |
|
| 142 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 143 |
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
if not self._logged:
|
| 158 |
+
logger.info(
|
| 159 |
+
"[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
| 160 |
+
offloaded_bytes / (1024**2),
|
| 161 |
+
)
|
| 162 |
|
| 163 |
# ------------------------------------------------------------------
|
| 164 |
def reload(self):
|
|
|
|
| 196 |
for i, mgd_idx in enumerate(indices):
|
| 197 |
local = self._local(self._managed[mgd_idx])
|
| 198 |
off, n = offsets[i]
|
| 199 |
+
local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
|
|
|
|
| 200 |
|
| 201 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 202 |
|
| 203 |
if not self._logged:
|
| 204 |
+
logger.info(
|
| 205 |
+
"[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
|
| 206 |
+
)
|
|
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1054 |
return super().state_dict()
|
| 1055 |
|
| 1056 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
if self.cpu_offload:
|
| 1058 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError(
|
| 1054 |
+
"Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
|
| 1055 |
+
)
|
| 1056 |
return super().state_dict()
|
| 1057 |
|
| 1058 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1059 |
if self.cpu_offload:
|
| 1060 |
+
raise RuntimeError(
|
| 1061 |
+
"Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
|
| 1062 |
+
)
|
| 1063 |
super().load_state_dict(state_dict)
|
| 1064 |
|
| 1065 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
-
LHS = np.array(
|
| 36 |
-
[
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 42 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 43 |
-
raise ValueError(
|
| 44 |
-
|
|
|
|
| 45 |
q, r = np.sqrt(
|
| 46 |
-
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 47 |
-
|
| 48 |
if not np.all(np.isfinite([q, r])):
|
| 49 |
-
raise ValueError(
|
| 50 |
-
f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
| 51 |
if abs(old_E - E) <= 1e-15:
|
| 52 |
break
|
| 53 |
else:
|
| 54 |
raise RuntimeError(
|
| 55 |
-
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
|
|
|
| 56 |
return float(a), float(b), float(c)
|
| 57 |
|
| 58 |
|
|
@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
|
| 111 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 112 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 113 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 114 |
-
_coeffs_list = _optimal_composition(
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
cushion=0.02)
|
| 118 |
|
| 119 |
|
| 120 |
# This code is adapted from:
|
|
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 148 |
|
| 149 |
X = X / (X.norm() + 1e-7)
|
| 150 |
hs = _coeffs_list[:steps] + list(
|
| 151 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 152 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 153 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 154 |
# Perform the NS iterations
|
|
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
|
|
| 183 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 184 |
|
| 185 |
hs = _coeffs_list[:steps] + list(
|
| 186 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 187 |
for a, b, c in hs:
|
| 188 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 189 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
+
LHS = np.array(
|
| 36 |
+
[
|
| 37 |
+
[l, l**3, l**5, 1],
|
| 38 |
+
[q, q**3, q**5, -1],
|
| 39 |
+
[r, r**3, r**5, 1],
|
| 40 |
+
[u, u**3, u**5, -1],
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 44 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
|
| 47 |
+
)
|
| 48 |
q, r = np.sqrt(
|
| 49 |
+
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
|
| 50 |
+
)
|
| 51 |
if not np.all(np.isfinite([q, r])):
|
| 52 |
+
raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
|
|
|
| 53 |
if abs(old_E - E) <= 1e-15:
|
| 54 |
break
|
| 55 |
else:
|
| 56 |
raise RuntimeError(
|
| 57 |
+
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
| 58 |
+
)
|
| 59 |
return float(a), float(b), float(c)
|
| 60 |
|
| 61 |
|
|
|
|
| 114 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 115 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 116 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 117 |
+
_coeffs_list = _optimal_composition(
|
| 118 |
+
l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
|
| 119 |
+
)
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
# This code is adapted from:
|
|
|
|
| 150 |
|
| 151 |
X = X / (X.norm() + 1e-7)
|
| 152 |
hs = _coeffs_list[:steps] + list(
|
| 153 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 154 |
+
)
|
| 155 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 156 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 157 |
# Perform the NS iterations
|
|
|
|
| 186 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 187 |
|
| 188 |
hs = _coeffs_list[:steps] + list(
|
| 189 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 190 |
+
)
|
| 191 |
for a, b, c in hs:
|
| 192 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 193 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
Returns:
|
| 19 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 23 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 24 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 25 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
|
|
|
|
|
|
| 26 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 27 |
"""
|
| 28 |
parts = normalize_fqn(name).split('.')
|
|
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 37 |
layer_idx = int(part)
|
| 38 |
break
|
| 39 |
|
| 40 |
-
if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
|
| 41 |
return kind, layer_idx
|
| 42 |
|
| 43 |
return None, -1
|
|
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 46 |
@dataclass
|
| 47 |
class QKClipInfo:
|
| 48 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 49 |
-
kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
|
| 50 |
indices: list[int] # which heads to consider for clipping
|
| 51 |
-
head_dim: int # from config
|
| 52 |
threshold: float # from config
|
| 53 |
logit: torch.Tensor | None
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 57 |
"""Extract QK clipping info for a named parameter.
|
| 58 |
|
| 59 |
Args:
|
| 60 |
clip_config: QK clipping configuration dict (or None).
|
|
|
|
|
|
|
| 61 |
n: Parameter name string.
|
| 62 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 63 |
|
|
@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
|
|
| 70 |
head_dim = clip_config.get('head_dim')
|
| 71 |
threshold = clip_config.get('threshold')
|
| 72 |
kind, layer_idx = parse_qk_layer(n)
|
|
|
|
| 73 |
|
| 74 |
logit, indices = None, []
|
| 75 |
if qk_logits is not None and kind is not None:
|
| 76 |
logit = qk_logits[layer_idx]
|
| 77 |
-
indices_key = 'q_indices' if 'q' in kind else 'k_indices'
|
| 78 |
-
indices = clip_config.get(indices_key, []) or []
|
| 79 |
-
|
| 80 |
if isinstance(logit, DTensor):
|
| 81 |
# In TP settings, qk_logits may be DTensor
|
| 82 |
# We convert it to full tensor here for simplicity
|
| 83 |
logit = logit.full_tensor()
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def compute_scales(p, qk_clip_state):
|
| 95 |
"""Compute per-head scaling factors for QK clipping.
|
| 96 |
|
| 97 |
-
Returns scales tensor if any head exceeds threshold, else None.
|
|
|
|
| 98 |
"""
|
| 99 |
kind = qk_clip_state.kind
|
| 100 |
indices = qk_clip_state.indices
|
|
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
|
|
| 118 |
if not head_scales:
|
| 119 |
return None
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 123 |
for head_idx, scale in head_scales.items():
|
| 124 |
scales_full[head_idx] = scale
|
| 125 |
return scales_full
|
| 126 |
|
| 127 |
|
| 128 |
-
def qk_clip(p, scales,
|
| 129 |
-
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
+
and return (kind, layer_index).
|
| 17 |
+
|
| 18 |
+
Supported kinds:
|
| 19 |
+
MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
|
| 20 |
+
MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
|
|
| 27 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 28 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 29 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
| 30 |
+
'model.1.attn.wq_b.weight' -> ('wq_b', 1)
|
| 31 |
+
'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
|
| 32 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 33 |
"""
|
| 34 |
parts = normalize_fqn(name).split('.')
|
|
|
|
| 43 |
layer_idx = int(part)
|
| 44 |
break
|
| 45 |
|
| 46 |
+
if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
|
| 47 |
return kind, layer_idx
|
| 48 |
|
| 49 |
return None, -1
|
|
|
|
| 52 |
@dataclass
|
| 53 |
class QKClipInfo:
|
| 54 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 55 |
+
kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
|
| 56 |
indices: list[int] # which heads to consider for clipping
|
| 57 |
+
head_dim: int # from config (qk_head_dim for MLA wq_b)
|
| 58 |
threshold: float # from config
|
| 59 |
logit: torch.Tensor | None
|
| 60 |
|
| 61 |
+
# MLA-specific fields
|
| 62 |
+
is_mla: bool = False
|
| 63 |
+
qk_nope_head_dim: int = 0
|
| 64 |
+
qk_rope_head_dim: int = 0
|
| 65 |
+
v_head_dim: int = 0
|
| 66 |
+
|
| 67 |
|
| 68 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 69 |
"""Extract QK clipping info for a named parameter.
|
| 70 |
|
| 71 |
Args:
|
| 72 |
clip_config: QK clipping configuration dict (or None).
|
| 73 |
+
MHA/GQA keys: head_dim, threshold, q_indices, k_indices
|
| 74 |
+
MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
|
| 75 |
n: Parameter name string.
|
| 76 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 77 |
|
|
|
|
| 84 |
head_dim = clip_config.get('head_dim')
|
| 85 |
threshold = clip_config.get('threshold')
|
| 86 |
kind, layer_idx = parse_qk_layer(n)
|
| 87 |
+
is_mla = clip_config.get('is_mla', False)
|
| 88 |
|
| 89 |
logit, indices = None, []
|
| 90 |
if qk_logits is not None and kind is not None:
|
| 91 |
logit = qk_logits[layer_idx]
|
|
|
|
|
|
|
|
|
|
| 92 |
if isinstance(logit, DTensor):
|
| 93 |
# In TP settings, qk_logits may be DTensor
|
| 94 |
# We convert it to full tensor here for simplicity
|
| 95 |
logit = logit.full_tensor()
|
| 96 |
|
| 97 |
+
if kind in ('wq_b', 'wq', 'q_proj'):
|
| 98 |
+
indices = clip_config.get('q_indices', []) or []
|
| 99 |
+
elif kind in ('wkv_b', 'wk', 'k_proj'):
|
| 100 |
+
indices = clip_config.get('k_indices', []) or []
|
| 101 |
+
|
| 102 |
+
if is_mla:
|
| 103 |
+
return QKClipInfo(
|
| 104 |
+
kind=kind,
|
| 105 |
+
indices=indices,
|
| 106 |
+
head_dim=head_dim,
|
| 107 |
+
threshold=threshold,
|
| 108 |
+
logit=logit,
|
| 109 |
+
is_mla=True,
|
| 110 |
+
qk_nope_head_dim=clip_config['qk_nope_head_dim'],
|
| 111 |
+
qk_rope_head_dim=clip_config['qk_rope_head_dim'],
|
| 112 |
+
v_head_dim=clip_config['v_head_dim'],
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
return QKClipInfo(
|
| 116 |
+
kind=kind,
|
| 117 |
+
indices=indices,
|
| 118 |
+
head_dim=head_dim,
|
| 119 |
+
threshold=threshold,
|
| 120 |
+
logit=logit,
|
| 121 |
+
)
|
| 122 |
|
| 123 |
|
| 124 |
def compute_scales(p, qk_clip_state):
|
| 125 |
"""Compute per-head scaling factors for QK clipping.
|
| 126 |
|
| 127 |
+
Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
|
| 128 |
+
For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
|
| 129 |
"""
|
| 130 |
kind = qk_clip_state.kind
|
| 131 |
indices = qk_clip_state.indices
|
|
|
|
| 149 |
if not head_scales:
|
| 150 |
return None
|
| 151 |
|
| 152 |
+
# For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
|
| 153 |
+
if qk_clip_state.is_mla and kind == 'wkv_b':
|
| 154 |
+
effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
|
| 155 |
+
else:
|
| 156 |
+
effective_head_dim = head_dim
|
| 157 |
+
|
| 158 |
+
H_global = p.shape[0] // effective_head_dim
|
| 159 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 160 |
for head_idx, scale in head_scales.items():
|
| 161 |
scales_full[head_idx] = scale
|
| 162 |
return scales_full
|
| 163 |
|
| 164 |
|
| 165 |
+
def qk_clip(p, scales, info):
|
| 166 |
+
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
p: Parameter (nn.Parameter or raw tensor).
|
| 170 |
+
scales: [n_heads] tensor, each element = √γ_h.
|
| 171 |
+
info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
|
| 172 |
+
|
| 173 |
+
MLA sub-region scaling per Algorithm 1 (MuonClip):
|
| 174 |
+
wq_b: q_nope rows → √γ, q_pe rows → γ
|
| 175 |
+
wkv_b: k_nope rows → √γ, v rows → unchanged
|
| 176 |
+
"""
|
| 177 |
+
W = p.data if isinstance(p, torch.nn.Parameter) else p
|
| 178 |
+
|
| 179 |
+
if not info.is_mla:
|
| 180 |
+
# MHA/GQA: uniform √γ applied to all rows in each head
|
| 181 |
+
W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
|
| 182 |
+
return
|
| 183 |
+
|
| 184 |
+
# MLA: vectorized sub-region scaling within each head
|
| 185 |
+
if info.kind == 'wq_b':
|
| 186 |
+
qk_nope = info.qk_nope_head_dim
|
| 187 |
+
qk_head_dim = qk_nope + info.qk_rope_head_dim
|
| 188 |
+
W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim]
|
| 189 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ
|
| 190 |
+
W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1,
|
| 191 |
+
1)) # q_pe → γ
|
| 192 |
+
|
| 193 |
+
elif info.kind == 'wkv_b':
|
| 194 |
+
qk_nope = info.qk_nope_head_dim
|
| 195 |
+
kv_stride = qk_nope + info.v_head_dim
|
| 196 |
+
W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim]
|
| 197 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ
|
| 198 |
+
# v rows: not touched (k_R shared rotary unchanged)
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8d53b78_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8d53b78_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8d53b78_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1999872
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:46baa92bf8f5ec5913df4081a01f662049fda475eb01bc7ed0f6154755fa88d5
|
| 3 |
size 1999872
|
|
@@ -93,10 +93,7 @@ class CPUOffloadPool:
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
-
cpu_flat = torch.empty(off,
|
| 97 |
-
dtype=dtype,
|
| 98 |
-
device="cpu",
|
| 99 |
-
pin_memory=True)
|
| 100 |
self._groups[dtype] = {
|
| 101 |
"indices": indices,
|
| 102 |
"offsets": offsets,
|
|
@@ -140,8 +137,7 @@ class CPUOffloadPool:
|
|
| 140 |
for i, mgd_idx in enumerate(indices):
|
| 141 |
local = self._local(self._managed[mgd_idx])
|
| 142 |
off, n = offsets[i]
|
| 143 |
-
cpu_flat[off:off + n].copy_(local.reshape(-1),
|
| 144 |
-
non_blocking=True)
|
| 145 |
|
| 146 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 147 |
|
|
@@ -159,8 +155,10 @@ class CPUOffloadPool:
|
|
| 159 |
)
|
| 160 |
|
| 161 |
if not self._logged:
|
| 162 |
-
logger.info(
|
| 163 |
-
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# ------------------------------------------------------------------
|
| 166 |
def reload(self):
|
|
@@ -198,12 +196,11 @@ class CPUOffloadPool:
|
|
| 198 |
for i, mgd_idx in enumerate(indices):
|
| 199 |
local = self._local(self._managed[mgd_idx])
|
| 200 |
off, n = offsets[i]
|
| 201 |
-
local.reshape(-1).copy_(cpu_flat[off:off + n],
|
| 202 |
-
non_blocking=True)
|
| 203 |
|
| 204 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 205 |
|
| 206 |
if not self._logged:
|
| 207 |
-
logger.info(
|
| 208 |
-
|
| 209 |
-
|
|
|
|
| 93 |
indices.append(idx)
|
| 94 |
offsets.append((off, n))
|
| 95 |
off += n
|
| 96 |
+
cpu_flat = torch.empty(off, dtype=dtype, device="cpu", pin_memory=True)
|
|
|
|
|
|
|
|
|
|
| 97 |
self._groups[dtype] = {
|
| 98 |
"indices": indices,
|
| 99 |
"offsets": offsets,
|
|
|
|
| 137 |
for i, mgd_idx in enumerate(indices):
|
| 138 |
local = self._local(self._managed[mgd_idx])
|
| 139 |
off, n = offsets[i]
|
| 140 |
+
cpu_flat[off : off + n].copy_(local.reshape(-1), non_blocking=True)
|
|
|
|
| 141 |
|
| 142 |
offloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 143 |
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
if not self._logged:
|
| 158 |
+
logger.info(
|
| 159 |
+
"[CPUOffload] Offloaded %.2f MB (GPU → CPU)",
|
| 160 |
+
offloaded_bytes / (1024**2),
|
| 161 |
+
)
|
| 162 |
|
| 163 |
# ------------------------------------------------------------------
|
| 164 |
def reload(self):
|
|
|
|
| 196 |
for i, mgd_idx in enumerate(indices):
|
| 197 |
local = self._local(self._managed[mgd_idx])
|
| 198 |
off, n = offsets[i]
|
| 199 |
+
local.reshape(-1).copy_(cpu_flat[off : off + n], non_blocking=True)
|
|
|
|
| 200 |
|
| 201 |
reloaded_bytes += grp["total"] * cpu_flat.element_size()
|
| 202 |
|
| 203 |
if not self._logged:
|
| 204 |
+
logger.info(
|
| 205 |
+
"[CPUOffload] Reloaded %.2f MB (CPU → GPU)", reloaded_bytes / (1024**2)
|
| 206 |
+
)
|
|
@@ -360,7 +360,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
@@ -407,7 +407,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
-
qk_clip(p, scales_full, qk_clip_state
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
@@ -1050,12 +1050,16 @@ class Muon(torch.optim.Optimizer):
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1054 |
return super().state_dict()
|
| 1055 |
|
| 1056 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1057 |
if self.cpu_offload:
|
| 1058 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 1059 |
super().load_state_dict(state_dict)
|
| 1060 |
|
| 1061 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
|
|
| 360 |
scales_full = compute_scales(
|
| 361 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 362 |
if scales_full is not None:
|
| 363 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 364 |
|
| 365 |
def distributed_muon(
|
| 366 |
self,
|
|
|
|
| 407 |
scales_full = compute_scales(
|
| 408 |
p, qk_clip_state) if qk_clip_state is not None else None
|
| 409 |
if scales_full is not None:
|
| 410 |
+
qk_clip(p, scales_full, qk_clip_state)
|
| 411 |
|
| 412 |
if not dtensor_params:
|
| 413 |
return
|
|
|
|
| 1050 |
|
| 1051 |
def state_dict(self) -> dict:
|
| 1052 |
if self.cpu_offload:
|
| 1053 |
+
raise RuntimeError(
|
| 1054 |
+
"Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
|
| 1055 |
+
)
|
| 1056 |
return super().state_dict()
|
| 1057 |
|
| 1058 |
def load_state_dict(self, state_dict: dict) -> None:
|
| 1059 |
if self.cpu_offload:
|
| 1060 |
+
raise RuntimeError(
|
| 1061 |
+
"Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
|
| 1062 |
+
)
|
| 1063 |
super().load_state_dict(state_dict)
|
| 1064 |
|
| 1065 |
# Invalidate adamw.py's module-level tensor caches so that
|
|
@@ -32,27 +32,30 @@ def _optimal_quintic(l, u, max_iter=1000):
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
-
LHS = np.array(
|
| 36 |
-
[
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 42 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 43 |
-
raise ValueError(
|
| 44 |
-
|
|
|
|
| 45 |
q, r = np.sqrt(
|
| 46 |
-
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) /
|
| 47 |
-
|
| 48 |
if not np.all(np.isfinite([q, r])):
|
| 49 |
-
raise ValueError(
|
| 50 |
-
f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
| 51 |
if abs(old_E - E) <= 1e-15:
|
| 52 |
break
|
| 53 |
else:
|
| 54 |
raise RuntimeError(
|
| 55 |
-
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
|
|
|
| 56 |
return float(a), float(b), float(c)
|
| 57 |
|
| 58 |
|
|
@@ -111,10 +114,9 @@ def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
|
|
| 111 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 112 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 113 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 114 |
-
_coeffs_list = _optimal_composition(
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
cushion=0.02)
|
| 118 |
|
| 119 |
|
| 120 |
# This code is adapted from:
|
|
@@ -148,7 +150,8 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 148 |
|
| 149 |
X = X / (X.norm() + 1e-7)
|
| 150 |
hs = _coeffs_list[:steps] + list(
|
| 151 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 152 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 153 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 154 |
# Perform the NS iterations
|
|
@@ -183,7 +186,8 @@ def _zeropower_via_newtonschulz5_batched(G, steps):
|
|
| 183 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 184 |
|
| 185 |
hs = _coeffs_list[:steps] + list(
|
| 186 |
-
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
|
|
|
| 187 |
for a, b, c in hs:
|
| 188 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 189 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
|
|
| 32 |
E = inf
|
| 33 |
for _ in range(max_iter):
|
| 34 |
old_E = E
|
| 35 |
+
LHS = np.array(
|
| 36 |
+
[
|
| 37 |
+
[l, l**3, l**5, 1],
|
| 38 |
+
[q, q**3, q**5, -1],
|
| 39 |
+
[r, r**3, r**5, 1],
|
| 40 |
+
[u, u**3, u**5, -1],
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
|
| 44 |
if not np.all(np.isfinite([a, b, c, E])):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
|
| 47 |
+
)
|
| 48 |
q, r = np.sqrt(
|
| 49 |
+
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
|
| 50 |
+
)
|
| 51 |
if not np.all(np.isfinite([q, r])):
|
| 52 |
+
raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
|
|
|
|
| 53 |
if abs(old_E - E) <= 1e-15:
|
| 54 |
break
|
| 55 |
else:
|
| 56 |
raise RuntimeError(
|
| 57 |
+
f"_optimal_quintic: did not converge after {max_iter} iterations"
|
| 58 |
+
)
|
| 59 |
return float(a), float(b), float(c)
|
| 60 |
|
| 61 |
|
|
|
|
| 114 |
# - Polar Express: analytically optimal per step, adapting to the shrinking
|
| 115 |
# singular-value interval [l, u] as iterations progress; converges all
|
| 116 |
# singular values to 1, producing the exact polar factor UV^T.
|
| 117 |
+
_coeffs_list = _optimal_composition(
|
| 118 |
+
l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
|
| 119 |
+
)
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
# This code is adapted from:
|
|
|
|
| 150 |
|
| 151 |
X = X / (X.norm() + 1e-7)
|
| 152 |
hs = _coeffs_list[:steps] + list(
|
| 153 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 154 |
+
)
|
| 155 |
buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 156 |
buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
|
| 157 |
# Perform the NS iterations
|
|
|
|
| 186 |
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
| 187 |
|
| 188 |
hs = _coeffs_list[:steps] + list(
|
| 189 |
+
repeat(_coeffs_list[-1], steps - len(_coeffs_list))
|
| 190 |
+
)
|
| 191 |
for a, b, c in hs:
|
| 192 |
buf1 = torch.bmm(X, X.transpose(-2, -1))
|
| 193 |
buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
|
|
@@ -13,7 +13,11 @@ logger = logging.getLogger(__name__)
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
Returns:
|
| 19 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
@@ -23,6 +27,8 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 23 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 24 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 25 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
|
|
|
|
|
|
| 26 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 27 |
"""
|
| 28 |
parts = normalize_fqn(name).split('.')
|
|
@@ -37,7 +43,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 37 |
layer_idx = int(part)
|
| 38 |
break
|
| 39 |
|
| 40 |
-
if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
|
| 41 |
return kind, layer_idx
|
| 42 |
|
| 43 |
return None, -1
|
|
@@ -46,18 +52,26 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 46 |
@dataclass
|
| 47 |
class QKClipInfo:
|
| 48 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 49 |
-
kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
|
| 50 |
indices: list[int] # which heads to consider for clipping
|
| 51 |
-
head_dim: int # from config
|
| 52 |
threshold: float # from config
|
| 53 |
logit: torch.Tensor | None
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 57 |
"""Extract QK clipping info for a named parameter.
|
| 58 |
|
| 59 |
Args:
|
| 60 |
clip_config: QK clipping configuration dict (or None).
|
|
|
|
|
|
|
| 61 |
n: Parameter name string.
|
| 62 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 63 |
|
|
@@ -70,31 +84,48 @@ def get_qk_clip_info(clip_config, n, qk_logits):
|
|
| 70 |
head_dim = clip_config.get('head_dim')
|
| 71 |
threshold = clip_config.get('threshold')
|
| 72 |
kind, layer_idx = parse_qk_layer(n)
|
|
|
|
| 73 |
|
| 74 |
logit, indices = None, []
|
| 75 |
if qk_logits is not None and kind is not None:
|
| 76 |
logit = qk_logits[layer_idx]
|
| 77 |
-
indices_key = 'q_indices' if 'q' in kind else 'k_indices'
|
| 78 |
-
indices = clip_config.get(indices_key, []) or []
|
| 79 |
-
|
| 80 |
if isinstance(logit, DTensor):
|
| 81 |
# In TP settings, qk_logits may be DTensor
|
| 82 |
# We convert it to full tensor here for simplicity
|
| 83 |
logit = logit.full_tensor()
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def compute_scales(p, qk_clip_state):
|
| 95 |
"""Compute per-head scaling factors for QK clipping.
|
| 96 |
|
| 97 |
-
Returns scales tensor if any head exceeds threshold, else None.
|
|
|
|
| 98 |
"""
|
| 99 |
kind = qk_clip_state.kind
|
| 100 |
indices = qk_clip_state.indices
|
|
@@ -118,18 +149,50 @@ def compute_scales(p, qk_clip_state):
|
|
| 118 |
if not head_scales:
|
| 119 |
return None
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 123 |
for head_idx, scale in head_scales.items():
|
| 124 |
scales_full[head_idx] = scale
|
| 125 |
return scales_full
|
| 126 |
|
| 127 |
|
| 128 |
-
def qk_clip(p, scales,
|
| 129 |
-
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
| 14 |
"""
|
| 15 |
Parse a parameter name to check if it is a query/key projection layer
|
| 16 |
+
and return (kind, layer_index).
|
| 17 |
+
|
| 18 |
+
Supported kinds:
|
| 19 |
+
MHA/GQA: 'wq', 'wk', 'q_proj', 'k_proj'
|
| 20 |
+
MLA: 'wq_b' (Q up-proj), 'wkv_b' (KV up-proj)
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
(kind, layer_idx) or (None, -1) if not matched.
|
|
|
|
| 27 |
'model.5.attn.wk.weight' -> ('wk', 5)
|
| 28 |
'model.2.attn.q_proj.weight' -> ('q_proj', 2)
|
| 29 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
| 30 |
+
'model.1.attn.wq_b.weight' -> ('wq_b', 1)
|
| 31 |
+
'model.0.attn.wkv_b.weight' -> ('wkv_b', 0)
|
| 32 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 33 |
"""
|
| 34 |
parts = normalize_fqn(name).split('.')
|
|
|
|
| 43 |
layer_idx = int(part)
|
| 44 |
break
|
| 45 |
|
| 46 |
+
if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
|
| 47 |
return kind, layer_idx
|
| 48 |
|
| 49 |
return None, -1
|
|
|
|
| 52 |
@dataclass
|
| 53 |
class QKClipInfo:
|
| 54 |
"""Per-parameter dynamic info computed from config + runtime logits."""
|
| 55 |
+
kind: str | None # 'wq'/'q_proj'/'wq_b' or 'wk'/'k_proj'/'wkv_b' or None
|
| 56 |
indices: list[int] # which heads to consider for clipping
|
| 57 |
+
head_dim: int # from config (qk_head_dim for MLA wq_b)
|
| 58 |
threshold: float # from config
|
| 59 |
logit: torch.Tensor | None
|
| 60 |
|
| 61 |
+
# MLA-specific fields
|
| 62 |
+
is_mla: bool = False
|
| 63 |
+
qk_nope_head_dim: int = 0
|
| 64 |
+
qk_rope_head_dim: int = 0
|
| 65 |
+
v_head_dim: int = 0
|
| 66 |
+
|
| 67 |
|
| 68 |
def get_qk_clip_info(clip_config, n, qk_logits):
|
| 69 |
"""Extract QK clipping info for a named parameter.
|
| 70 |
|
| 71 |
Args:
|
| 72 |
clip_config: QK clipping configuration dict (or None).
|
| 73 |
+
MHA/GQA keys: head_dim, threshold, q_indices, k_indices
|
| 74 |
+
MLA extra keys: is_mla=True, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
|
| 75 |
n: Parameter name string.
|
| 76 |
qk_logits: Dict mapping layer indices to logit tensors (or None).
|
| 77 |
|
|
|
|
| 84 |
head_dim = clip_config.get('head_dim')
|
| 85 |
threshold = clip_config.get('threshold')
|
| 86 |
kind, layer_idx = parse_qk_layer(n)
|
| 87 |
+
is_mla = clip_config.get('is_mla', False)
|
| 88 |
|
| 89 |
logit, indices = None, []
|
| 90 |
if qk_logits is not None and kind is not None:
|
| 91 |
logit = qk_logits[layer_idx]
|
|
|
|
|
|
|
|
|
|
| 92 |
if isinstance(logit, DTensor):
|
| 93 |
# In TP settings, qk_logits may be DTensor
|
| 94 |
# We convert it to full tensor here for simplicity
|
| 95 |
logit = logit.full_tensor()
|
| 96 |
|
| 97 |
+
if kind in ('wq_b', 'wq', 'q_proj'):
|
| 98 |
+
indices = clip_config.get('q_indices', []) or []
|
| 99 |
+
elif kind in ('wkv_b', 'wk', 'k_proj'):
|
| 100 |
+
indices = clip_config.get('k_indices', []) or []
|
| 101 |
+
|
| 102 |
+
if is_mla:
|
| 103 |
+
return QKClipInfo(
|
| 104 |
+
kind=kind,
|
| 105 |
+
indices=indices,
|
| 106 |
+
head_dim=head_dim,
|
| 107 |
+
threshold=threshold,
|
| 108 |
+
logit=logit,
|
| 109 |
+
is_mla=True,
|
| 110 |
+
qk_nope_head_dim=clip_config['qk_nope_head_dim'],
|
| 111 |
+
qk_rope_head_dim=clip_config['qk_rope_head_dim'],
|
| 112 |
+
v_head_dim=clip_config['v_head_dim'],
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
return QKClipInfo(
|
| 116 |
+
kind=kind,
|
| 117 |
+
indices=indices,
|
| 118 |
+
head_dim=head_dim,
|
| 119 |
+
threshold=threshold,
|
| 120 |
+
logit=logit,
|
| 121 |
+
)
|
| 122 |
|
| 123 |
|
| 124 |
def compute_scales(p, qk_clip_state):
|
| 125 |
"""Compute per-head scaling factors for QK clipping.
|
| 126 |
|
| 127 |
+
Returns scales tensor (√γ per head) if any head exceeds threshold, else None.
|
| 128 |
+
For MLA wkv_b, effective row stride is qk_nope_head_dim + v_head_dim.
|
| 129 |
"""
|
| 130 |
kind = qk_clip_state.kind
|
| 131 |
indices = qk_clip_state.indices
|
|
|
|
| 149 |
if not head_scales:
|
| 150 |
return None
|
| 151 |
|
| 152 |
+
# For MLA wkv_b, each KV head spans qk_nope_head_dim + v_head_dim rows
|
| 153 |
+
if qk_clip_state.is_mla and kind == 'wkv_b':
|
| 154 |
+
effective_head_dim = qk_clip_state.qk_nope_head_dim + qk_clip_state.v_head_dim
|
| 155 |
+
else:
|
| 156 |
+
effective_head_dim = head_dim
|
| 157 |
+
|
| 158 |
+
H_global = p.shape[0] // effective_head_dim
|
| 159 |
scales_full = torch.ones(H_global, device=p.data.device)
|
| 160 |
for head_idx, scale in head_scales.items():
|
| 161 |
scales_full[head_idx] = scale
|
| 162 |
return scales_full
|
| 163 |
|
| 164 |
|
| 165 |
+
def qk_clip(p, scales, info):
|
| 166 |
+
"""Apply per-head scaling to a Q/K projection weight matrix.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
p: Parameter (nn.Parameter or raw tensor).
|
| 170 |
+
scales: [n_heads] tensor, each element = √γ_h.
|
| 171 |
+
info: QKClipInfo with kind, head_dim, and MLA sub-head dimensions.
|
| 172 |
+
|
| 173 |
+
MLA sub-region scaling per Algorithm 1 (MuonClip):
|
| 174 |
+
wq_b: q_nope rows → √γ, q_pe rows → γ
|
| 175 |
+
wkv_b: k_nope rows → √γ, v rows → unchanged
|
| 176 |
+
"""
|
| 177 |
+
W = p.data if isinstance(p, torch.nn.Parameter) else p
|
| 178 |
+
|
| 179 |
+
if not info.is_mla:
|
| 180 |
+
# MHA/GQA: uniform √γ applied to all rows in each head
|
| 181 |
+
W.view(-1, info.head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
|
| 182 |
+
return
|
| 183 |
+
|
| 184 |
+
# MLA: vectorized sub-region scaling within each head
|
| 185 |
+
if info.kind == 'wq_b':
|
| 186 |
+
qk_nope = info.qk_nope_head_dim
|
| 187 |
+
qk_head_dim = qk_nope + info.qk_rope_head_dim
|
| 188 |
+
W_3d = W.view(-1, qk_head_dim, W.shape[1]) # [H, qk_head_dim, in_dim]
|
| 189 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # q_nope → √γ
|
| 190 |
+
W_3d[:, qk_nope:, :].mul_((scales * scales).view(-1, 1,
|
| 191 |
+
1)) # q_pe → γ
|
| 192 |
+
|
| 193 |
+
elif info.kind == 'wkv_b':
|
| 194 |
+
qk_nope = info.qk_nope_head_dim
|
| 195 |
+
kv_stride = qk_nope + info.v_head_dim
|
| 196 |
+
W_3d = W.view(-1, kv_stride, W.shape[1]) # [H, kv_stride, in_dim]
|
| 197 |
+
W_3d[:, :qk_nope, :].mul_(scales.view(-1, 1, 1)) # k_nope → √γ
|
| 198 |
+
# v rows: not touched (k_R shared rotary unchanged)
|
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8d53b78_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8d53b78_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8d53b78_dirty::{op_name}"
|
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1865080
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bcf5b8838dfaf6e81fdbd52ff4638ca76abaa678f7c2cbd81cf03dc72f9cd5d2
|
| 3 |
size 1865080
|