QK-Clip for MuonClip Optimizer (MLA)
Reference: Kimi K2 Technical Report, Section 2.1, Algorithm 1
Overview

QK-Clip is a weight-rescaling technique that prevents the attention logit explosion observed with the Muon optimizer. It leaves the forward/backward passes untouched; after each optimizer step it rescales the weights, cutting off logit growth at the source.
Algorithm 1: MuonClip

```
for each training step t:
    # 1. Muon optimizer step
    for each weight W:
        M_t = μ·M_{t-1} + G_t
        O_t = NewtonSchulz(M_t) · √(max(n, m)) · 0.2
        W_t = W_{t-1} − η·(O_t + λ·W_{t-1})
    # 2. QK-Clip
    for each attention head h:
        S^h_max ← max pre-softmax logit of head h, recorded during the forward pass
        if S^h_max > τ:
            γ ← τ / S^h_max
            W^h_qc ← W^h_qc · √γ   (query compressed, q_nope)
            W^h_kc ← W^h_kc · √γ   (key compressed, k_nope)
            W^h_qr ← W^h_qr · γ    (query rotary, q_pe)
            # k_R (shared rotary, k_pe): left untouched
```
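To see why the scaling rule above pulls the logit back exactly to τ, here is a toy one-dimensional check (hypothetical numbers, not from the report): the per-head logit is q_nope·k_nope + q_rope·k_rope; scaling q_nope and k_nope by √γ and q_rope by γ (with the shared k_pe untouched) multiplies the whole logit by γ = τ / S^h_max.

```python
import math

tau = 100.0
q_nope, k_nope = 6.0, 30.0   # toy scalar stand-ins for the nope sub-vectors
q_rope, k_rope = 4.0, 15.0   # toy stand-ins for the rope sub-vectors
s_max = q_nope * k_nope + q_rope * k_rope  # 180 + 60 = 240 > tau
gamma = tau / s_max
sg = math.sqrt(gamma)

# Apply the Algorithm 1 rule: nope parts get sqrt(gamma), q_rope gets gamma,
# k_rope (shared k_pe) is untouched.
clipped = (q_nope * sg) * (k_nope * sg) + (q_rope * gamma) * k_rope
print(clipped)  # 100.0 — the logit lands exactly on the threshold
```

Each dot-product term picks up a total factor of γ, so the max logit is mapped to exactly τ.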
Existing code → MLA pseudocode

Current code structure (MHA/GQA)

- parse_qk_layer(name) → determines whether the weight is wq/wk, extracts the layer index
- get_qk_clip_info(config, n) → QKClipInfo (kind, indices, head_dim, threshold, logit)
- compute_scales(p, info) → returns a tensor of per-head √γ scales
- qk_clip(p, scales, head_dim) → W.view(-1, head_dim, in_dim).mul_(scales)

The current code assumes a uniform head_dim and applies the same √γ across the entire Q/K weight.
What changes with MLA

| Item | MHA/GQA (current) | MLA |
|---|---|---|
| Q weight | wq / q_proj | wq_b (up-projection from the low-rank LoRA-style path) |
| K weight | wk / k_proj | wkv_b (k_nope and v stored concatenated) |
| Q head stride | qk_head_dim (uniform) | qk_head_dim = qk_nope_head_dim + qk_rope_head_dim |
| K head stride | qk_head_dim (uniform) | kv_stride = qk_nope_head_dim + v_head_dim |
| Q scaling | √γ everywhere | nope → √γ, rope → γ (different per sub-region) |
| K scaling | √γ everywhere | k_nope → √γ, v → 1.0 (partial only) |
| shared k_pe | none | tail of wkv_a; left untouched |
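The stride arithmetic in the table can be sanity-checked with the example dims used later in this note (qk_nope=128, qk_rope=64, v=128); the row ranges below follow directly from the per-head strides:

```python
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128

# wq_b rows per head: q_nope block followed by q_pe block
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim   # 192
# wkv_b rows per KV head: k_nope block followed by v block
kv_stride = qk_nope_head_dim + v_head_dim           # 256

# Row ranges for head h = 2
h = 2
q_nope_rows = (h * qk_head_dim, h * qk_head_dim + qk_nope_head_dim)
q_pe_rows   = (h * qk_head_dim + qk_nope_head_dim, (h + 1) * qk_head_dim)
k_nope_rows = (h * kv_stride, h * kv_stride + qk_nope_head_dim)
print(qk_head_dim, kv_stride, q_nope_rows, q_pe_rows, k_nope_rows)
# 192 256 (384, 512) (512, 576) (512, 640)
```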
Pseudocode: parse_qk_layer (MLA extension)

```python
def parse_qk_layer(name: str) -> tuple[str | None, int]:
    parts = normalize_fqn(name).split('.')
    kind = parts[-2] if len(parts) >= 2 else ''
    layer_idx = -1
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break
    # MHA/GQA: wq, wk, q_proj, k_proj
    # MLA: wq_b (Q up-proj), wkv_b (KV up-proj)
    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
        return kind, layer_idx
    return None, -1
```
Pseudocode: QKClipInfo (MLA extension)

```python
@dataclass
class QKClipInfo:
    kind: str | None       # 'wq_b' or 'wkv_b' (MLA) / 'wq', 'wk' (MHA)
    indices: list[int]     # head indices subject to clipping
    head_dim: int          # uniform stride, for the existing MHA path
    threshold: float
    logit: torch.Tensor | None
    # MLA-only fields
    is_mla: bool = False
    qk_nope_head_dim: int = 0
    qk_rope_head_dim: int = 0
    v_head_dim: int = 0
```
Pseudocode: get_qk_clip_info (MLA extension)

```python
def get_qk_clip_info(clip_config, n, qk_logits):
    if clip_config is None:
        return None
    threshold = clip_config['threshold']
    kind, layer_idx = parse_qk_layer(n)
    is_mla = clip_config.get('is_mla', False)
    logit, indices = None, []
    if qk_logits is not None and kind is not None:
        logit = qk_logits[layer_idx]
        if isinstance(logit, DTensor):
            logit = logit.full_tensor()
        if kind in ('wq_b', 'wq', 'q_proj'):
            indices = clip_config.get('q_indices', []) or []
        elif kind in ('wkv_b', 'wk', 'k_proj'):
            indices = clip_config.get('k_indices', []) or []
    if is_mla:
        return QKClipInfo(
            kind=kind,
            indices=indices,
            head_dim=clip_config['head_dim'],  # qk_head_dim (for wq_b)
            threshold=threshold,
            logit=logit,
            is_mla=True,
            qk_nope_head_dim=clip_config['qk_nope_head_dim'],
            qk_rope_head_dim=clip_config['qk_rope_head_dim'],
            v_head_dim=clip_config['v_head_dim'],
        )
    else:
        # Existing MHA/GQA path
        return QKClipInfo(
            kind=kind, indices=indices,
            head_dim=clip_config['head_dim'],
            threshold=threshold, logit=logit,
        )
```
Pseudocode: compute_scales (MLA extension)

Per-head γ is determined exactly as before (the γ decision is identical to MHA). What changes is the application in qk_clip: each head is split into sub-regions, and a different transform is applied to each.
```python
def compute_scales(p, qk_clip_state):
    """Same as the existing code. Returns per-head √γ."""
    kind = qk_clip_state.kind
    indices = qk_clip_state.indices
    threshold = qk_clip_state.threshold
    logit = qk_clip_state.logit
    head_scales = {}
    for logit_idx, head_idx in enumerate(indices):
        v_ele = float(logit[logit_idx])
        if v_ele > threshold:
            new_scale = math.sqrt(threshold / v_ele)  # √γ
            if head_idx not in head_scales or new_scale < head_scales[head_idx]:
                head_scales[head_idx] = new_scale
    if not head_scales:
        return None
    # MLA: head_dim here is qk_head_dim (wq_b) or kv_stride (wkv_b)
    H_global = p.shape[0] // qk_clip_state.head_dim
    scales_full = torch.ones(H_global, device=p.data.device)
    for head_idx, scale in head_scales.items():
        scales_full[head_idx] = scale  # √γ_h
    return scales_full
```
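The √γ selection above can be traced on toy numbers (hypothetical per-head max logits, threshold τ = 100): only heads whose recorded logit exceeds τ receive a scale below 1.

```python
import math

threshold = 100.0
indices = [0, 1, 2]            # heads being inspected
logits = [250.0, 90.0, 121.0]  # hypothetical per-head S^h_max values

head_scales = {}
for logit_idx, head_idx in enumerate(indices):
    v_ele = logits[logit_idx]
    if v_ele > threshold:
        head_scales[head_idx] = math.sqrt(threshold / v_ele)  # √γ

print({h: round(s, 4) for h, s in sorted(head_scales.items())})
# {0: 0.6325, 2: 0.9091} — head 1 (logit 90 < τ) is left alone
```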
Pseudocode: qk_clip (MLA extension)

It receives the same per-head scales (√γ), but applies a different transform to each sub-region within a head.
```python
def qk_clip(p, scales, head_dim, is_mla=False, kind=None, info=None):
    """
    scales: [n_heads] tensor, each element = √γ_h
    is_mla=False: existing MHA/GQA path (uniform √γ within a head)
    is_mla=True:  MLA (different transform per sub-region within a head)
    """
    W = p.data if isinstance(p, torch.nn.Parameter) else p
    if not is_mla:
        # Existing path: apply √γ uniformly to every row of the head
        W.view(-1, head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
        return
    # MLA: split each head into sub-regions
    if kind == 'wq_b':
        qk_nope = info.qk_nope_head_dim
        qk_rope = info.qk_rope_head_dim
        qk_head_dim = qk_nope + qk_rope
        for h in range(len(scales)):
            sqrt_gamma = scales[h].item()
            if sqrt_gamma >= 1.0:
                continue
            gamma = sqrt_gamma * sqrt_gamma  # √γ → γ
            s = h * qk_head_dim
            W[s : s + qk_nope] *= sqrt_gamma           # q_nope → √γ
            W[s + qk_nope : s + qk_head_dim] *= gamma  # q_pe → γ
    elif kind == 'wkv_b':
        qk_nope = info.qk_nope_head_dim
        kv_stride = qk_nope + info.v_head_dim
        for h in range(len(scales)):
            sqrt_gamma = scales[h].item()
            if sqrt_gamma >= 1.0:
                continue
            s = h * kv_stride
            W[s : s + qk_nope] *= sqrt_gamma  # k_nope → √γ
            # v rows: left untouched
```
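The per-row effect of the wq_b branch can be checked without torch. This is a list-based sketch with deliberately tiny dims (qk_nope=2, qk_rope=1, 2 heads): head 0 is clipped with γ=0.25, head 1 is untouched.

```python
import math

qk_nope, qk_rope = 2, 1
qk_head_dim = qk_nope + qk_rope
# 2 heads × 3 rows, every entry starts at 1.0
W = [[1.0, 1.0] for _ in range(2 * qk_head_dim)]
scales = [math.sqrt(0.25), 1.0]  # √γ per head

for h, sqrt_gamma in enumerate(scales):
    if sqrt_gamma >= 1.0:
        continue
    gamma = sqrt_gamma * sqrt_gamma
    s = h * qk_head_dim
    for r in range(s, s + qk_nope):                # q_nope rows → √γ
        W[r] = [w * sqrt_gamma for w in W[r]]
    for r in range(s + qk_nope, s + qk_head_dim):  # q_pe rows → γ
        W[r] = [w * gamma for w in W[r]]

print([row[0] for row in W])  # [0.5, 0.5, 0.25, 1.0, 1.0, 1.0]
```

Rows 0-1 (q_nope of head 0) pick up √γ = 0.5, row 2 (q_pe) picks up γ = 0.25, and head 1's rows stay at 1.0.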
Pseudocode: handling wkv_b indices under GQA

A Q-head → KV-head mapping is needed. Since multiple Q heads share the same KV head, the KV weight must be scaled exactly once, using the smallest γ within the group.
```python
def build_k_indices_for_mla(clip_config, n_heads, n_kv_heads):
    """
    Builds KV head indices from Q-head-based logits.
    If q_indices are Q head indices, k_indices must be converted
    to the corresponding KV head indices.
    Caution: among the Q heads mapped to the same KV head,
    the largest logit (= smallest gamma) must be used.
    """
    heads_per_kv = n_heads // n_kv_heads
    q_indices = clip_config.get('q_indices', list(range(n_heads)))
    # Q head → KV head mapping.
    # Taking the max over the Q heads that map to the same kv_head in the
    # logit tensor is handled by the min(gamma) logic inside compute_scales_mla.
    k_indices = []
    seen = set()
    for q_idx in q_indices:
        kv_idx = q_idx // heads_per_kv
        if kv_idx not in seen:
            k_indices.append(kv_idx)
            seen.add(kv_idx)
    return k_indices
```
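A usage sketch of the mapping with 8 Q heads sharing 2 KV heads (heads_per_kv = 4, so Q heads 0-3 map to KV 0 and Q heads 4-7 to KV 1); note each KV head appears at most once even when several of its Q heads are listed:

```python
def build_k_indices_for_mla(clip_config, n_heads, n_kv_heads):
    heads_per_kv = n_heads // n_kv_heads
    q_indices = clip_config.get('q_indices', list(range(n_heads)))
    k_indices, seen = [], set()
    for q_idx in q_indices:
        kv_idx = q_idx // heads_per_kv  # Q head → KV head
        if kv_idx not in seen:          # deduplicate within a KV group
            k_indices.append(kv_idx)
            seen.add(kv_idx)
    return k_indices

print(build_k_indices_for_mla({'q_indices': [1, 5, 6]}, 8, 2))  # [0, 1]
print(build_k_indices_for_mla({}, 8, 2))                        # [0, 1]
```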
Pseudocode: call flow (integrated)

```python
# Called after the optimizer step (existing code structure preserved)
for name, param in model.named_parameters():
    info = get_qk_clip_info(clip_config, name, qk_logits)
    if info is None or info.kind is None:
        continue
    scales = compute_scales(param, info)  # per-head √γ (shared by MHA and MLA)
    if scales is not None:
        qk_clip(param, scales, info.head_dim,
                is_mla=info.is_mla, kind=info.kind, info=info)
```
Pseudocode: example clip_config

```python
# MHA/GQA (existing)
clip_config = {
    'head_dim': 128,
    'threshold': 100.0,
    'q_indices': list(range(n_heads)),
    'k_indices': list(range(n_kv_heads)),
}

# MLA (extended)
clip_config = {
    'is_mla': True,
    'head_dim': 192,  # qk_head_dim (= qk_nope + qk_rope)
    'qk_nope_head_dim': 128,
    'qk_rope_head_dim': 64,
    'v_head_dim': 128,
    'threshold': 100.0,
    'q_indices': list(range(n_heads)),
    'k_indices': list(range(n_kv_heads)),  # produced by build_k_indices_for_mla
}
```
Row index mapping table

| Algorithm symbol | Tensor | Row range | Scale |
|---|---|---|---|
| W^h_qc | wq_b.weight | [h*qk_head_dim : h*qk_head_dim + qk_nope_head_dim] | √γ |
| W^h_qr | wq_b.weight | [h*qk_head_dim + qk_nope_head_dim : (h+1)*qk_head_dim] | γ |
| W^h_kc | wkv_b.weight | [kv_h*kv_stride : kv_h*kv_stride + qk_nope_head_dim] | √γ |
| k_R | tail of the wkv_a output | - | untouched |

Here kv_stride = qk_nope_head_dim + v_head_dim, and kv_h = h // (n_heads // n_kv_heads) (GQA head mapping).
Hyperparameters

| Parameter | Value | Notes |
|---|---|---|
| τ (threshold) | 100 | K2 full-scale training |
| τ (aggressive) | 30 | small-scale ablation; no performance degradation observed |
Notes

- Self-deactivation: in K2, only 12.7% of heads triggered during the first 70k steps. Afterwards, S_max drops below τ for every head and clipping naturally deactivates.
- DP/TP environments: S^h_max must be gathered as a max over all ranks via all-reduce.
- Avoiding duplicate application under GQA: for the group of Q heads sharing one KV head, scale the KV weight exactly once, using the smallest γ (= largest logit). Handled by the min(gamma) logic in compute_scales_mla.
- wq_b_gate: participates only in the output gate, not in the attention logits, so it is not a QK-Clip target.
- Existing logit soft-cap: keep it as a forward-level safety net, while adding optimizer-level QK-Clip on top is the paper's approach.
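For context on the last note: this document does not spell out the soft-cap formula, but a common choice (e.g. Gemma 2-style logit soft-capping, an assumption here) is cap·tanh(logit/cap). It bounds logits smoothly in the forward pass, while QK-Clip keeps the pre-cap logits from growing in the first place.

```python
import math

def soft_cap(logit: float, cap: float = 100.0) -> float:
    # Assumed tanh-style forward-level soft-cap; bounds output to (-cap, cap)
    return cap * math.tanh(logit / cap)

print(round(soft_cap(30.0), 2))   # small logits pass nearly unchanged
print(round(soft_cap(500.0), 2))  # large logits saturate just below the cap
```

Unlike QK-Clip, the soft-cap changes the forward computation itself, which is why the note treats the two as complementary rather than interchangeable.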