
QK-Clip for MuonClip Optimizer (MLA)

Reference: Kimi K2 Technical Report, Section 2.1, Algorithm 1

Overview

QK-Clip์€ Muon optimizer์—์„œ ๋ฐœ์ƒํ•˜๋Š” attention logit explosion์„ ๋ฐฉ์ง€ํ•˜๊ธฐ ์œ„ํ•œ weight rescaling ๊ธฐ๋ฒ•์ด๋‹ค. forward/backward์—๋Š” ๊ฐœ์ž…ํ•˜์ง€ ์•Š๊ณ , optimizer step ์ดํ›„์— weight๋ฅผ rescaleํ•˜์—ฌ logit ์„ฑ์žฅ์„ ์›์ฒœ ์ฐจ๋‹จํ•œ๋‹ค.

Algorithm 1: MuonClip

for each training step t:
    // 1. Muon optimizer step
    for each weight W:
        Mt = µ·Mt-1 + Gt
        Ot = Newton-Schulz(Mt) · √max(n,m) · 0.2
        Wt = Wt-1 - η·(Ot + λ·Wt-1)

    // 2. QK-Clip
    for each attention head h:
        S^h_max ← max pre-softmax logit of head h recorded in the forward pass
        if S^h_max > τ:
            γ ← τ / S^h_max
            W^h_qc ← W^h_qc · √γ      (query compressed, q_nope)
            W^h_kc ← W^h_kc · √γ      (key compressed, k_nope)
            W^h_qr ← W^h_qr · γ       (query rotary, q_pe)
            // k_R (shared rotary, k_pe): left untouched
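The γ arithmetic in step 2 can be sketched as a small standalone helper. The function name is ours, and the τ and logit values in the comment are illustrative only:

```python
import math

def qk_clip_factors(s_max: float, tau: float) -> tuple[float, float]:
    """Return (sqrt_gamma, gamma) for one head; (1.0, 1.0) if no clipping triggers.

    sqrt_gamma is applied to the compressed (nope) rows of W_qc and W_kc,
    gamma to the rotary rows of W_qr; the shared k_pe is left untouched.
    """
    if s_max <= tau:
        return 1.0, 1.0
    gamma = tau / s_max
    return math.sqrt(gamma), gamma

# A head with max logit 400 and tau=100 gets gamma=0.25, sqrt_gamma=0.5.
```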

Existing code → MLA pseudocode

Current code structure (MHA/GQA)

parse_qk_layer(name)          → determine wq/wk kind, extract layer index
get_qk_clip_info(config, n)   → QKClipInfo (kind, indices, head_dim, threshold, logit)
compute_scales(p, info)       → return per-head √γ scales tensor
qk_clip(p, scales, head_dim)  → W.view(-1, head_dim, in_dim).mul_(scales)

The current code assumes a uniform head_dim and applies the same √γ to the entire Q/K weight.

What changes with MLA

| Item | MHA/GQA (current) | MLA |
|---|---|---|
| Q weight | wq / q_proj | wq_b (up-proj from LoRA) |
| K weight | wk / k_proj | wkv_b (k_nope and v concatenated) |
| Q head stride | qk_head_dim (uniform) | qk_head_dim = qk_nope_head_dim + qk_rope_head_dim |
| K head stride | qk_head_dim (uniform) | kv_stride = qk_nope_head_dim + v_head_dim |
| Q scaling | √γ everywhere | nope → √γ, rope → γ (different) |
| K scaling | √γ everywhere | k_nope → √γ, v → 1.0 (partial only) |
| shared k_pe | none | tail of wkv_a output, left untouched |
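The two stride formulas diverge as soon as v_head_dim differs from qk_rope_head_dim; a quick check with assumed MLA dimensions (values are illustrative, not taken from the report):

```python
# Illustrative MLA dimensions (assumptions for this sketch).
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128

q_stride = qk_nope_head_dim + qk_rope_head_dim   # per-head rows in wq_b
kv_stride = qk_nope_head_dim + v_head_dim        # per-head rows in wkv_b

print(q_stride, kv_stride)  # 192 256 — Q and KV head strides no longer match
```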

์ˆ˜๋„์ฝ”๋“œ: parse_qk_layer (MLA ํ™•์žฅ)

def parse_qk_layer(name: str) -> tuple[str | None, int]:
    parts = normalize_fqn(name).split('.')
    kind = parts[-2]

    layer_idx = -1
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break

    # MHA/GQA: wq, wk, q_proj, k_proj
    # MLA:     wq_b (Q up-proj), wkv_b (KV up-proj)
    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
        return kind, layer_idx

    return None, -1
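A self-contained version of the parser can be exercised directly. Here `normalize_fqn` is replaced with an identity stand-in, since its real definition lives elsewhere in the codebase; the parameter names in the comment are illustrative:

```python
def normalize_fqn(name: str) -> str:
    # Stand-in: the real helper normalizes FQNs across wrappers; identity here.
    return name

def parse_qk_layer(name: str):
    parts = normalize_fqn(name).split('.')
    kind = parts[-2]  # e.g. 'wq_b' in 'model.layers.3.attn.wq_b.weight'

    # Layer index = last numeric path component.
    layer_idx = -1
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break

    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
        return kind, layer_idx
    return None, -1

# parse_qk_layer('model.layers.3.attn.wq_b.weight') -> ('wq_b', 3)
# parse_qk_layer('model.layers.0.mlp.w1.weight')    -> (None, -1)
```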

์ˆ˜๋„์ฝ”๋“œ: QKClipInfo (MLA ํ™•์žฅ)

@dataclass
class QKClipInfo:
    kind: str | None          # 'wq_b' or 'wkv_b' (MLA) / 'wq','wk' (MHA)
    indices: list[int]        # head indices subject to clipping
    head_dim: int             # for the existing MHA path (uniform stride)
    threshold: float
    logit: torch.Tensor | None

    # MLA-only fields
    is_mla: bool = False
    qk_nope_head_dim: int = 0
    qk_rope_head_dim: int = 0
    v_head_dim: int = 0

์ˆ˜๋„์ฝ”๋“œ: get_qk_clip_info (MLA ํ™•์žฅ)

def get_qk_clip_info(clip_config, n, qk_logits):
    if clip_config is None:
        return None

    threshold = clip_config['threshold']
    kind, layer_idx = parse_qk_layer(n)
    is_mla = clip_config.get('is_mla', False)

    logit, indices = None, []
    if qk_logits is not None and kind is not None:
        logit = qk_logits[layer_idx]
        if isinstance(logit, DTensor):
            logit = logit.full_tensor()

        if kind in ('wq_b', 'wq', 'q_proj'):
            indices = clip_config.get('q_indices', []) or []
        elif kind in ('wkv_b', 'wk', 'k_proj'):
            indices = clip_config.get('k_indices', []) or []

    if is_mla:
        return QKClipInfo(
            kind=kind,
            indices=indices,
            head_dim=clip_config['head_dim'],          # qk_head_dim (for wq_b)
            threshold=threshold,
            logit=logit,
            is_mla=True,
            qk_nope_head_dim=clip_config['qk_nope_head_dim'],
            qk_rope_head_dim=clip_config['qk_rope_head_dim'],
            v_head_dim=clip_config['v_head_dim'],
        )
    else:
        # existing MHA/GQA path
        return QKClipInfo(
            kind=kind, indices=indices,
            head_dim=clip_config['head_dim'],
            threshold=threshold, logit=logit,
        )

์ˆ˜๋„์ฝ”๋“œ: compute_scales (MLA ํ™•์žฅ)

๊ธฐ์กด๊ณผ ๋™์ผํ•˜๊ฒŒ per-head ฮณ๋ฅผ ๊ณ„์‚ฐํ•œ๋‹ค. (ฮณ ๊ฒฐ์ •์€ MHA์™€ ๋™์ผ) ๋‹ฌ๋ผ์ง€๋Š” ๊ฑด qk_clip ์ ์šฉ ์‹œ head ๋‚ด๋ถ€๋ฅผ sub-region๋ณ„๋กœ ๋‚˜๋ˆ ์„œ ๋‹ค๋ฅธ ๋ณ€ํ™˜์„ ์“ฐ๋Š” ๊ฒƒ์ด๋‹ค.

def compute_scales(p, qk_clip_state):
    """Identical to the existing code: returns per-head √γ."""
    kind = qk_clip_state.kind
    indices = qk_clip_state.indices
    threshold = qk_clip_state.threshold
    logit = qk_clip_state.logit

    head_scales = {}
    for logit_idx, head_idx in enumerate(indices):
        v_ele = float(logit[logit_idx])
        if v_ele > threshold:
            new_scale = math.sqrt(threshold / v_ele)      # √γ
            if head_idx not in head_scales or new_scale < head_scales[head_idx]:
                head_scales[head_idx] = new_scale

    if not head_scales:
        return None

    H_global = p.shape[0] // qk_clip_state.head_dim      # MLA: head_dim = qk_head_dim or kv_stride
    scales_full = torch.ones(H_global, device=p.data.device)
    for head_idx, scale in head_scales.items():
        scales_full[head_idx] = scale                     # √γ_h

    return scales_full
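The scale computation itself needs no tensors; a plain-Python sketch (function name ours) shows the per-head min reduction over √γ:

```python
import math

def sqrt_gamma_per_head(logits, indices, threshold, n_heads):
    """Per-head √γ; 1.0 where clipping does not trigger.

    logits[i] is the max pre-softmax logit recorded for head indices[i];
    if a head appears more than once, the smallest scale (largest logit) wins.
    """
    scales = [1.0] * n_heads
    for logit, h in zip(logits, indices):
        if logit > threshold:
            scales[h] = min(scales[h], math.sqrt(threshold / logit))
    return scales

# Heads 0 and 2 exceed tau=100; head 1 does not:
# sqrt_gamma_per_head([400.0, 80.0, 225.0], [0, 1, 2], 100.0, 3)
```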

์ˆ˜๋„์ฝ”๋“œ: qk_clip (MLA ํ™•์žฅ)

per-head scales(โˆšฮณ)๋Š” ๋™์ผํ•˜๊ฒŒ ๋ฐ›๋˜, head ๋‚ด๋ถ€ sub-region์— ๋‹ค๋ฅธ ํ•จ์ˆ˜๋ฅผ ์ ์šฉํ•œ๋‹ค.

def qk_clip(p, scales, head_dim, is_mla=False, kind=None, info=None):
    """
    scales: [n_heads] tensor, each element = √γ_h

    is_mla=False: existing MHA/GQA (uniform √γ within a head)
    is_mla=True:  MLA (different transforms per sub-region within a head)
    """
    W = p.data if isinstance(p, torch.nn.Parameter) else p

    if not is_mla:
        # Existing path: apply √γ uniformly to every row of a head
        W.view(-1, head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
        return

    # MLA: apply per sub-region within each head
    if kind == 'wq_b':
        qk_nope = info.qk_nope_head_dim
        qk_rope = info.qk_rope_head_dim
        qk_head_dim = qk_nope + qk_rope

        for h in range(len(scales)):
            sqrt_gamma = scales[h].item()
            if sqrt_gamma >= 1.0:
                continue
            gamma = sqrt_gamma * sqrt_gamma      # √γ → γ
            s = h * qk_head_dim

            W[s : s + qk_nope]               *= sqrt_gamma   # q_nope → √γ
            W[s + qk_nope : s + qk_head_dim] *= gamma        # q_pe   → γ

    elif kind == 'wkv_b':
        qk_nope = info.qk_nope_head_dim
        kv_stride = qk_nope + info.v_head_dim

        for h in range(len(scales)):
            sqrt_gamma = scales[h].item()
            if sqrt_gamma >= 1.0:
                continue
            s = h * kv_stride

            W[s : s + qk_nope] *= sqrt_gamma                 # k_nope → √γ
            # v rows: left untouched
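The wq_b branch can be checked without torch by treating the weight as a flat list of per-row scalars (a toy stand-in with shrunken dims; function name ours):

```python
def clip_wq_b_rows(W, scales, qk_nope, qk_rope):
    """Scale q_nope rows by √γ and q_pe rows by γ, in place.

    W is a list of per-row scalars standing in for weight rows;
    scales[h] is √γ for head h.
    """
    qk_head_dim = qk_nope + qk_rope
    for h, sqrt_gamma in enumerate(scales):
        if sqrt_gamma >= 1.0:
            continue
        gamma = sqrt_gamma * sqrt_gamma
        s = h * qk_head_dim
        for r in range(s, s + qk_nope):
            W[r] *= sqrt_gamma            # q_nope rows -> sqrt(gamma)
        for r in range(s + qk_nope, s + qk_head_dim):
            W[r] *= gamma                 # q_pe rows   -> gamma

W = [1.0] * 6                             # 2 heads x (qk_nope=2 + qk_rope=1) rows
clip_wq_b_rows(W, [0.5, 1.0], qk_nope=2, qk_rope=1)
# head 0: q_nope rows x0.5, q_pe row x0.25; head 1 untouched
# W == [0.5, 0.5, 0.25, 1.0, 1.0, 1.0]
```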

์ˆ˜๋„์ฝ”๋“œ: GQA์—์„œ wkv_b indices ์ฒ˜๋ฆฌ

Q head โ†’ KV head ๋งคํ•‘์ด ํ•„์š”ํ•˜๋‹ค. ์—ฌ๋Ÿฌ Q head๊ฐ€ ๊ฐ™์€ KV head๋ฅผ ๊ณต์œ ํ•˜๋ฏ€๋กœ, group ๋‚ด ์ตœ์†Œ gamma ๊ธฐ์ค€์œผ๋กœ ํ•œ ๋ฒˆ๋งŒ ์ ์šฉํ•ด์•ผ ํ•œ๋‹ค.

def build_k_indices_for_mla(clip_config, n_heads, n_kv_heads):
    """
    Build KV head indices from Q-head-based logits.
    If q_indices are Q head indices, k_indices must be
    converted to the corresponding KV head indices.

    Note: among the Q heads mapping to the same KV head,
          the largest logit (= smallest gamma) must be used.
    """
    heads_per_kv = n_heads // n_kv_heads
    q_indices = clip_config.get('q_indices', list(range(n_heads)))

    # Q head → KV head mapping.
    # Taking the max over the Q heads that map to the same kv_head in the
    # logit tensor is handled by the min(gamma) logic inside compute_scales_mla.

    k_indices = []
    seen = set()
    for q_idx in q_indices:
        kv_idx = q_idx // heads_per_kv
        if kv_idx not in seen:
            k_indices.append(kv_idx)
            seen.add(kv_idx)

    return k_indices
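For instance, with 8 Q heads sharing 2 KV heads, duplicate group members collapse to a single KV index. The same logic, inlined for a standalone check (function name ours):

```python
def q_to_kv_indices(q_indices, n_heads, n_kv_heads):
    """Dedup logic of build_k_indices_for_mla, inlined for a standalone demo."""
    heads_per_kv = n_heads // n_kv_heads
    k_indices, seen = [], set()
    for q_idx in q_indices:
        kv_idx = q_idx // heads_per_kv      # GQA group of this Q head
        if kv_idx not in seen:
            k_indices.append(kv_idx)
            seen.add(kv_idx)
    return k_indices

# q_to_kv_indices([0, 1, 5, 6], n_heads=8, n_kv_heads=2) -> [0, 1]
```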

์ˆ˜๋„์ฝ”๋“œ: ํ˜ธ์ถœ ํ๋ฆ„ (ํ†ตํ•ฉ)

# optimizer step ์ดํ›„ ํ˜ธ์ถœ๋˜๋Š” ๋ถ€๋ถ„ (๊ธฐ์กด ์ฝ”๋“œ ๊ตฌ์กฐ ์œ ์ง€)

for name, param in model.named_parameters():
    info = get_qk_clip_info(clip_config, name, qk_logits)
    if info is None or info.kind is None:
        continue

    scales = compute_scales(param, info)    # per-head โˆšฮณ (MHA/MLA ๊ณตํ†ต)
    if scales is not None:
        qk_clip(param, scales, info.head_dim,
                is_mla=info.is_mla, kind=info.kind, info=info)

์ˆ˜๋„์ฝ”๋“œ: clip_config ์˜ˆ์‹œ

# MHA/GQA (existing)
clip_config = {
    'head_dim': 128,
    'threshold': 100.0,
    'q_indices': list(range(n_heads)),
    'k_indices': list(range(n_kv_heads)),
}

# MLA (extension)
clip_config = {
    'is_mla': True,
    'head_dim': 192,                  # qk_head_dim (= qk_nope + qk_rope)
    'qk_nope_head_dim': 128,
    'qk_rope_head_dim': 64,
    'v_head_dim': 128,
    'threshold': 100.0,
    'q_indices': list(range(n_heads)),
    'k_indices': list(range(n_kv_heads)),  # generated via build_k_indices_for_mla
}

Row index mapping table

| Algorithm symbol | Tensor | Row range | Scale |
|---|---|---|---|
| W^h_qc | wq_b.weight | [h*qk_head_dim : h*qk_head_dim + qk_nope_head_dim] | √γ |
| W^h_qr | wq_b.weight | [h*qk_head_dim + qk_nope_head_dim : (h+1)*qk_head_dim] | γ |
| W^h_kc | wkv_b.weight | [kv_h*kv_stride : kv_h*kv_stride + qk_nope_head_dim] | √γ |
| k_R | wkv_a output, tail | - | untouched |

  • kv_stride = qk_nope_head_dim + v_head_dim
  • kv_h = h // (n_heads // n_kv_heads) (GQA head mapping)
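The table rows translate mechanically into index arithmetic; a hypothetical helper (name ours) that returns the row ranges touched for Q head h:

```python
def mla_row_ranges(h, qk_nope, qk_rope, v_dim, heads_per_kv):
    """Row ranges touched by QK-Clip for Q head h, per the mapping table.

    Returns name -> (start, end, scale) where scale names the factor applied.
    """
    qk_head_dim = qk_nope + qk_rope
    kv_stride = qk_nope + v_dim
    kv_h = h // heads_per_kv                 # GQA head mapping
    return {
        'W_qc': (h * qk_head_dim, h * qk_head_dim + qk_nope, 'sqrt_gamma'),
        'W_qr': (h * qk_head_dim + qk_nope, (h + 1) * qk_head_dim, 'gamma'),
        'W_kc': (kv_h * kv_stride, kv_h * kv_stride + qk_nope, 'sqrt_gamma'),
    }

# With qk_nope=128, qk_rope=64, v_dim=128, heads_per_kv=4 and h=1:
# W_qc spans rows [192, 320), W_qr rows [320, 384), W_kc rows [0, 128).
```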

Hyperparameters

| Parameter | Value | Notes |
|---|---|---|
| τ (threshold) | 100 | K2 full-scale training |
| τ (aggressive) | 30 | small-scale ablation; no performance degradation observed |

Notes

  • Self-deactivation: in K2, only 12.7% of heads triggered during the first 70k steps; afterwards every head's S_max fell below τ and clipping deactivated naturally.
  • DP/TP setups: S^h_max must be collected across all ranks with a max all-reduce.
  • Avoiding duplicate application under GQA: within the group of Q heads sharing a KV head, scale the KV weight exactly once using the smallest gamma (= largest logit). Handled by the min(gamma) logic in compute_scales_mla.
  • wq_b_gate: only feeds the output gate, not the attention logits, so it is not a QK-Clip target.
  • Existing logit soft-cap: keep it as a forward-level safety net; adding optimizer-level QK-Clip on top is the paper's approach.
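The GQA dedup note amounts to a max-per-group reduction over logits (equivalently, min over gamma); a minimal sketch assuming contiguous Q-head groups (function name ours):

```python
def kv_group_max_logits(q_logits, n_kv_heads):
    """Max per-Q-head logit within each KV group.

    The largest logit yields the smallest gamma, which is the one
    that must drive the single KV-weight scaling for that group.
    """
    heads_per_kv = len(q_logits) // n_kv_heads
    return [max(q_logits[g * heads_per_kv:(g + 1) * heads_per_kv])
            for g in range(n_kv_heads)]

# kv_group_max_logits([120.0, 90.0, 80.0, 150.0], n_kv_heads=2) -> [120.0, 150.0]
```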