# QK-Clip for MuonClip Optimizer (MLA) > Reference: [Kimi K2 Technical Report](https://arxiv.org/pdf/2507.20534), Section 2.1, Algorithm 1 ## 개요 QK-Clip은 Muon optimizer에서 발생하는 attention logit explosion을 방지하기 위한 **weight rescaling** 기법이다. forward/backward에는 개입하지 않고, optimizer step **이후**에 weight를 rescale하여 logit 성장을 원천 차단한다. ## Algorithm 1: MuonClip ``` for each training step t: // 1. Muon optimizer step for each weight W: Mt = µ·Mt-1 + Gt Ot = Newton-Schulz(Mt) · √max(n,m) · 0.2 Wt = Wt-1 - η·(Ot + λ·Wt-1) // 2. QK-Clip for each attention head h: S^h_max ← forward에서 기록한 head h의 max pre-softmax logit if S^h_max > τ: γ ← τ / S^h_max W^h_qc ← W^h_qc · √γ (query compressed, q_nope) W^h_kc ← W^h_kc · √γ (key compressed, k_nope) W^h_qr ← W^h_qr · γ (query rotary, q_pe) // k_R (shared rotary, k_pe): 안 건드림 ``` ## 기존 코드 → MLA 수도코드 ### 현재 코드 구조 (MHA/GQA) ``` parse_qk_layer(name) → wq/wk 여부 판별, layer index 추출 get_qk_clip_info(config, n) → QKClipInfo (kind, indices, head_dim, threshold, logit) compute_scales(p, info) → per-head √γ scales 텐서 반환 qk_clip(p, scales, head_dim) → W.view(-1, head_dim, in_dim).mul_(scales) ``` 현재 코드는 head_dim이 균일하고, Q/K weight 전체에 동일한 √γ를 적용한다. ### MLA에서 달라지는 점 | 항목 | MHA/GQA (현재) | MLA | |---|---|---| | Q weight | `wq` / `q_proj` | `wq_b` (up-proj from LoRA) | | K weight | `wk` / `k_proj` | `wkv_b` (k_nope + v 합쳐져 있음) | | Q head stride | `qk_head_dim` (균일) | `qk_head_dim` = `qk_nope_head_dim + qk_rope_head_dim` | | K head stride | `qk_head_dim` (균일) | `kv_stride` = `qk_nope_head_dim + v_head_dim` | | Q scaling | 전체 √γ | nope → √γ, rope → γ (서로 다름) | | K scaling | 전체 √γ | k_nope → √γ, v → 1.0 (부분만) | | shared k_pe | 없음 | `wkv_a` 뒷부분, 안 건드림 | ### 수도코드: parse_qk_layer (MLA 확장) ```python def parse_qk_layer(name: str) -> tuple[str | None, int]: parts = normalize_fqn(name).split('.') kind = parts[-2] layer_idx = -1 for part in reversed(parts): if part.isdigit(): layer_idx = int(part) break # MHA/GQA: wq, wk, q_proj, k_proj # MLA: wq_b (Q up-proj), wkv_b (KV up-proj) if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'): return kind, layer_idx return None, -1 ``` ### 수도코드: QKClipInfo (MLA 확장) ```python @dataclass class QKClipInfo: kind: str | None # 'wq_b' or 'wkv_b' (MLA) / 'wq','wk' (MHA) indices: list[int] # clipping 대상 head indices head_dim: int # 기존 MHA용 (uniform stride) threshold: float logit: torch.Tensor | None # MLA 전용 필드 is_mla: bool = False qk_nope_head_dim: int = 0 qk_rope_head_dim: int = 0 v_head_dim: int = 0 ``` ### 수도코드: get_qk_clip_info (MLA 확장) ```python def get_qk_clip_info(clip_config, n, qk_logits): if clip_config is None: return None threshold = clip_config['threshold'] kind, layer_idx = parse_qk_layer(n) is_mla = clip_config.get('is_mla', False) logit, indices = None, [] if qk_logits is not None and kind is not None: logit = qk_logits[layer_idx] if isinstance(logit, DTensor): logit = logit.full_tensor() if kind in ('wq_b', 'wq', 'q_proj'): indices = clip_config.get('q_indices', []) or [] elif kind in ('wkv_b', 'wk', 'k_proj'): indices = clip_config.get('k_indices', []) or [] if is_mla: return QKClipInfo( kind=kind, indices=indices, head_dim=clip_config['head_dim'], # qk_head_dim (for wq_b) threshold=threshold, logit=logit, is_mla=True, qk_nope_head_dim=clip_config['qk_nope_head_dim'], qk_rope_head_dim=clip_config['qk_rope_head_dim'], v_head_dim=clip_config['v_head_dim'], ) else: # 기존 MHA/GQA 경로 return QKClipInfo( kind=kind, indices=indices, head_dim=clip_config['head_dim'], threshold=threshold, logit=logit, ) ``` ### 수도코드: compute_scales (MLA 확장) 기존과 동일하게 per-head γ를 계산한다. (γ 결정은 MHA와 동일) 달라지는 건 `qk_clip` 적용 시 head 내부를 sub-region별로 나눠서 다른 변환을 쓰는 것이다. ```python def compute_scales(p, qk_clip_state): """기존 코드와 동일. per-head √γ 반환.""" kind = qk_clip_state.kind indices = qk_clip_state.indices threshold = qk_clip_state.threshold logit = qk_clip_state.logit head_scales = {} for logit_idx, head_idx in enumerate(indices): v_ele = float(logit[logit_idx]) if v_ele > threshold: new_scale = math.sqrt(threshold / v_ele) # √γ if head_idx not in head_scales or new_scale < head_scales[head_idx]: head_scales[head_idx] = new_scale if not head_scales: return None H_global = p.shape[0] // qk_clip_state.head_dim # MLA: head_dim = qk_head_dim or kv_stride scales_full = torch.ones(H_global, device=p.data.device) for head_idx, scale in head_scales.items(): scales_full[head_idx] = scale # √γ_h return scales_full ``` ### 수도코드: qk_clip (MLA 확장) per-head scales(√γ)는 동일하게 받되, head 내부 sub-region에 다른 함수를 적용한다. ```python def qk_clip(p, scales, head_dim, is_mla=False, kind=None, info=None): """ scales: [n_heads] 텐서, 각 원소 = √γ_h is_mla=False: 기존 MHA/GQA (head 내 uniform √γ) is_mla=True: MLA (head 내 sub-region별 다른 변환) """ W = p.data if isinstance(p, torch.nn.Parameter) else p if not is_mla: # 기존: 모든 행에 √γ 균일 적용 W.view(-1, head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1)) return # MLA: head별로 sub-region 분리 적용 if kind == 'wq_b': qk_nope = info.qk_nope_head_dim qk_rope = info.qk_rope_head_dim qk_head_dim = qk_nope + qk_rope for h in range(len(scales)): sqrt_gamma = scales[h].item() if sqrt_gamma >= 1.0: continue gamma = sqrt_gamma * sqrt_gamma # √γ → γ s = h * qk_head_dim W[s : s + qk_nope] *= sqrt_gamma # q_nope → √γ W[s + qk_nope : s + qk_head_dim] *= gamma # q_pe → γ elif kind == 'wkv_b': qk_nope = info.qk_nope_head_dim kv_stride = qk_nope + info.v_head_dim for h in range(len(scales)): sqrt_gamma = scales[h].item() if sqrt_gamma >= 1.0: continue s = h * kv_stride W[s : s + qk_nope] *= sqrt_gamma # k_nope → √γ # v 행: 안 건드림 ``` ### 수도코드: GQA에서 wkv_b indices 처리 Q head → KV head 매핑이 필요하다. 여러 Q head가 같은 KV head를 공유하므로, **group 내 최소 gamma** 기준으로 한 번만 적용해야 한다. ```python def build_k_indices_for_mla(clip_config, n_heads, n_kv_heads): """ Q head 기준 logit으로부터 KV head indices를 생성한다. q_indices가 Q head index 기준이라면, k_indices는 대응되는 KV head index로 변환해야 한다. 주의: 같은 KV head에 매핑되는 여러 Q head 중 가장 큰 logit (= 가장 작은 gamma)을 사용해야 한다. """ heads_per_kv = n_heads // n_kv_heads q_indices = clip_config.get('q_indices', list(range(n_heads))) # Q head → KV head 매핑 # logit 텐서에서 같은 kv_head에 대응되는 Q head들 중 max를 취하는 것은 # compute_scales_mla 내부에서 min(gamma) 로 처리됨 k_indices = [] seen = set() for q_idx in q_indices: kv_idx = q_idx // heads_per_kv if kv_idx not in seen: k_indices.append(kv_idx) seen.add(kv_idx) return k_indices ``` ### 수도코드: 호출 흐름 (통합) ```python # optimizer step 이후 호출되는 부분 (기존 코드 구조 유지) for name, param in model.named_parameters(): info = get_qk_clip_info(clip_config, name, qk_logits) if info is None or info.kind is None: continue scales = compute_scales(param, info) # per-head √γ (MHA/MLA 공통) if scales is not None: qk_clip(param, scales, info.head_dim, is_mla=info.is_mla, kind=info.kind, info=info) ``` ### 수도코드: clip_config 예시 ```python # MHA/GQA (기존) clip_config = { 'head_dim': 128, 'threshold': 100.0, 'q_indices': list(range(n_heads)), 'k_indices': list(range(n_kv_heads)), } # MLA (확장) clip_config = { 'is_mla': True, 'head_dim': 192, # qk_head_dim (= qk_nope + qk_rope) 'qk_nope_head_dim': 128, 'qk_rope_head_dim': 64, 'v_head_dim': 128, 'threshold': 100.0, 'q_indices': list(range(n_heads)), 'k_indices': list(range(n_kv_heads)), # build_k_indices_for_mla로 생성 } ``` ## 행 인덱스 매핑 테이블 | 알고리즘 기호 | 텐서 | 행 범위 | scale | |---|---|---|---| | W^h_qc | `wq_b.weight` | `[h*qk_head_dim : h*qk_head_dim + qk_nope_head_dim]` | √γ | | W^h_qr | `wq_b.weight` | `[h*qk_head_dim + qk_nope_head_dim : (h+1)*qk_head_dim]` | γ | | W^h_kc | `wkv_b.weight` | `[kv_h*kv_stride : kv_h*kv_stride + qk_nope_head_dim]` | √γ | | k_R | `wkv_a` output 뒷부분 | - | 안 건드림 | - `kv_stride = qk_nope_head_dim + v_head_dim` - `kv_h = h // (n_heads // n_kv_heads)` (GQA head 매핑) ## 하이퍼파라미터 | 파라미터 | 값 | 비고 | |---|---|---| | τ (threshold) | 100 | K2 full-scale 학습 | | τ (aggressive) | 30 | 소규모 ablation, 성능 저하 없음 확인 | ## 참고사항 - **Self-deactivation**: K2에서 초기 70k step 동안 12.7%의 head만 trigger됨. 이후 모든 head의 S_max가 τ 아래로 내려가면서 자연스럽게 비활성화. - **DP/TP 환경**: S^h_max를 all-reduce로 모든 rank에서 max 수집 필요. - **GQA 중복 적용 방지**: 같은 KV head를 공유하는 Q head group에서 가장 작은 gamma(= 가장 큰 logit)를 기준으로 KV weight를 한 번만 scaling. `compute_scales_mla`에서 `min(gamma)` 로직으로 처리. - **wq_b_gate**: attention logit이 아닌 output gate에만 관여하므로 QK-Clip 대상 아님. - **기존 logit soft-cap**: forward-level safety net으로 남겨두되, optimizer-level QK-Clip을 추가하는 것이 논문의 접근법.