"""Convert an auto-round / GPTQ W4A16 packed HuggingFace checkpoint of DeepSeek-V4
into the MP-sharded local format consumed by `model.py`/`generate.py`.

Packing convention (auto-round → auto_gptq):
- qweight : int32 [in_features // 8, out_features], LSB-first 4-bit packed along dim 0
- qzeros  : int32 [in_features // group_size, out_features // 8], LSB-first 4-bit packed along dim 1
- scales  : fp16  [in_features // group_size, out_features]
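
Illustration (hypothetical sizes, assuming the standard auto_gptq LSB-first layout above):
for a linear with in_features=1024, out_features=512, group_size=128 the packed tensors are
qweight [128, 512] int32, qzeros [8, 64] int32 and scales [8, 512] fp16, and the 4-bit value
for logical input row r and output column c is (qweight[r // 8, c] >> (4 * (r % 8))) & 0xF.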

Sharding rules per linear:
- ColumnParallel (shard the output dim; original `dim=0` in `mapping`):
    qweight  along dim 1; qzeros along dim 1 (out_features must be divisible by
    8 * world_size, since qzeros packs 8 outputs per int32); scales along dim 1.
- RowParallel    (shard the input dim; original `dim=1` in `mapping`):
    qweight  along dim 0 (in_features must be divisible by 8 * world_size);
    qzeros   along dim 0 (in_features must be divisible by group_size * world_size);
    scales   along dim 0.

Non-quantised tensors (embed.weight, *.norm.weight, attn_sink, hc_*, ape, gate.bias,
gate.tid2eid, etc.) follow the same rules as the original `convert.py`.
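
Example invocation (hypothetical script name, paths and values; pick --n-experts and
--model-parallel to match the target checkpoint and GPU count):

    python convert_quant.py --hf-ckpt-path /path/to/hf_w4a16_ckpt \
        --save-path /path/to/converted_ckpt --n-experts 256 --model-parallel 8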
"""

import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange

import torch
from safetensors.torch import safe_open, save_file


GROUP_SIZE = 128

# Same name remapping as the original convert.py
mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "lm_head": ("head", 0),

    # Already-translated names (used by the inference checkpoints we already have)
    "embed": ("embed", 0),
    "wq_a": ("wq_a", None),
    "wq_b": ("wq_b", 0),
    "wkv": ("wkv", None),
    "wo_a": ("wo_a", 0),
    "wo_b": ("wo_b", 1),
    "w1": ("w1", 0),
    "w2": ("w2", 1),
    "w3": ("w3", 0),
    "head": ("head", 0),
    "weights_proj": ("weights_proj", 0),
    # special non-weight keys
    "attn_sink": ("attn_sink", 0),
    "ape": ("ape", None),
    # NOTE: 'gate' is intentionally NOT in this mapping -- the routing gate is a
    # plain nn.Parameter that is replicated on every rank.
}
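
# Illustrative rename (hypothetical but typical HF name): after stripping the "model."
# prefix and applying the `self_attn` -> `attn` substitution in `main`,
# "model.layers.0.self_attn.q_b_proj.weight" becomes "layers.0.attn.wq_b.weight"
# and is sharded along dim 0 (column parallel).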


# Suffixes that mark the three pieces of a packed W4A16 linear.
QUANT_SUFFIXES = (".qweight", ".qzeros", ".scales")


def shard_quant(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor,
                shard_dim: int, mp: int):
    """Yield (qweight_i, qzeros_i, scales_i) for i in range(mp).

    shard_dim is the *logical* dim of the dequantised weight: 0 == output (column parallel),
    1 == input (row parallel)."""
    out = qweight.size(1)
    in_packed = qweight.size(0)             # in_features // 8
    n_groups = scales.size(0)               # in_features // group_size

    if shard_dim == 0:                      # ColumnParallel: shard along OUTPUT
        assert out % mp == 0, f"out={out} not divisible by mp={mp}"
        # qzeros packs 8 outputs per int32 in dim 1, so need (out/mp) % 8 == 0
        assert (out // mp) % 8 == 0, f"shard {out//mp} of out dim not divisible by 8 (qzeros packing)"
        sh_out = out // mp
        sh_qz_cols = qzeros.size(1) // mp   # == out / 8 / mp
        for i in range(mp):
            yield (
                qweight.narrow(1, i * sh_out, sh_out).contiguous(),
                qzeros.narrow(1, i * sh_qz_cols, sh_qz_cols).contiguous(),
                scales.narrow(1, i * sh_out, sh_out).contiguous(),
            )
    elif shard_dim == 1:                    # RowParallel: shard along INPUT
        # qweight packs 8 inputs per int32 in dim 0, scales/qzeros are per-group on dim 0
        assert in_packed % mp == 0, f"in_packed={in_packed} not divisible by mp={mp}"
        assert n_groups % mp == 0, f"n_groups={n_groups} not divisible by mp={mp}"
        sh_in_packed = in_packed // mp
        sh_groups = n_groups // mp
        for i in range(mp):
            yield (
                qweight.narrow(0, i * sh_in_packed, sh_in_packed).contiguous(),
                qzeros.narrow(0, i * sh_groups, sh_groups).contiguous(),
                scales.narrow(0, i * sh_groups, sh_groups).contiguous(),
            )
    else:
        # Replicate
        for _ in range(mp):
            yield qweight, qzeros, scales
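

# Minimal illustrative sanity check for `shard_quant` (a sketch, not called from `main`;
# the tensor sizes below are hypothetical). Re-concatenating the shards along the
# sharded packed dim should reproduce the original tensors bit-for-bit.
def _sanity_check_shard_quant(mp: int = 2, in_features: int = 256, out_features: int = 64):
    qweight = torch.randint(0, 2**31 - 1, (in_features // 8, out_features), dtype=torch.int32)
    qzeros = torch.randint(0, 2**31 - 1, (in_features // GROUP_SIZE, out_features // 8), dtype=torch.int32)
    scales = torch.rand(in_features // GROUP_SIZE, out_features).to(torch.float16)
    for shard_dim in (0, 1):                 # 0 == column parallel, 1 == row parallel
        cat_dim = 1 - shard_dim              # dim the shards were narrowed along
        qw_parts, qz_parts, sc_parts = zip(*shard_quant(qweight, qzeros, scales, shard_dim, mp))
        assert torch.equal(torch.cat(qw_parts, dim=cat_dim), qweight)
        assert torch.equal(torch.cat(qz_parts, dim=cat_dim), qzeros)
        assert torch.equal(torch.cat(sc_parts, dim=cat_dim), scales)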


def get_layer_key(name: str):
    """Return the linear-name token (e.g. wq_a, w1, head) used for the rename mapping."""
    parts = name.split(".")
    if name.endswith(QUANT_SUFFIXES):
        return parts[-2]                    # ...x.qweight  -> x
    if name.endswith(".bias") and "gate" in name:
        return "gate"                       # ffn.gate.bias
    if name.endswith(".tid2eid"):
        return "gate"
    if any(k in parts for k in ("hc_attn_fn", "hc_attn_base", "hc_attn_scale",
                                 "hc_ffn_fn", "hc_ffn_base", "hc_ffn_scale",
                                 "hc_head_fn", "hc_head_base", "hc_head_scale",
                                 "attn_sink", "ape")):
        return parts[-1]
    return parts[-2]


def main(hf_ckpt_path, save_path, n_experts, mp):
    torch.set_num_threads(8)
    n_local_experts = n_experts // mp
    state_dicts = [{} for _ in range(mp)]

    # Group all fragments belonging to the same logical linear so we can shard
    # qweight/qzeros/scales together.
    pending: dict[str, dict[str, torch.Tensor]] = {}

    def emit_linear(base_name: str, parts: dict[str, torch.Tensor], shard_dim):
        """Distribute a quantised linear (3 tensors) across `mp` shards."""
        qweight = parts["qweight"]
        qzeros = parts["qzeros"]
        scales = parts["scales"].to(torch.bfloat16)   # store bf16 instead of fp16
        # Expert-local pruning: only the rank that owns this expert keeps the tensors.
        if "experts" in base_name and "shared_experts" not in base_name:
            idx = int(base_name.split(".experts.")[1].split(".")[0])
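            # e.g. with n_experts=64 and mp=8 (hypothetical values), n_local_experts == 8,
            # so expert 13 is kept only on rank 13 // 8 == 1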
            owner = idx // n_local_experts
            state_dicts[owner][base_name + ".qweight"] = qweight
            state_dicts[owner][base_name + ".qzeros"] = qzeros
            state_dicts[owner][base_name + ".scales"] = scales
            return
        if shard_dim is None:
            # Replicate across all ranks
            for i in range(mp):
                state_dicts[i][base_name + ".qweight"] = qweight
                state_dicts[i][base_name + ".qzeros"] = qzeros
                state_dicts[i][base_name + ".scales"] = scales
        else:
            for i, (qw, qz, sc) in enumerate(shard_quant(qweight, qzeros, scales, shard_dim, mp)):
                state_dicts[i][base_name + ".qweight"] = qw
                state_dicts[i][base_name + ".qzeros"] = qz
                state_dicts[i][base_name + ".scales"] = sc

    files = sorted(glob(os.path.join(hf_ckpt_path, "*.safetensors")))
    for file_path in tqdm(files, desc="files"):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for orig_name in f.keys():
                # ----- name remapping (mirrors original convert.py) -----
                name = orig_name
                if name.startswith("model."):
                    name = name[len("model."):]
                if name.startswith("mtp.") and ("emb" in name or name.endswith("head.weight")):
                    continue
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")

                key = get_layer_key(name)
                if key in mapping:
                    new_key, dim = mapping[key]
                    name = name.replace(key, new_key)
                else:
                    dim = None

                tensor = f.get_tensor(orig_name)

                # ----- handle the three-piece quantised linear -----
                # `shared_experts` are plain (non-parallel) Linears in the model;
                # never shard them even though `w1/w2/w3` are in the mapping.
                if "shared_experts" in name:
                    dim = None

                if orig_name.endswith(QUANT_SUFFIXES):
                    base = name.rsplit(".", 1)[0]
                    suf = name.rsplit(".", 1)[1]                # qweight|qzeros|scales
                    pending.setdefault(base, {"_dim": dim})[suf] = tensor
                    pending[base]["_dim"] = dim
                    parts = pending[base]
                    if all(s in parts for s in ("qweight", "qzeros", "scales")):
                        emit_linear(base, parts, parts["_dim"])
                        del pending[base]
                    continue

                # ----- non-quantised tensor -----
                if "experts" in name and "shared_experts" not in name:
                    idx = int(name.split(".experts.")[1].split(".")[0])
                    owner = idx // n_local_experts
                    state_dicts[owner][name] = tensor
                    continue

                if dim is None:
                    for i in range(mp):
                        state_dicts[i][name] = tensor
                else:
                    assert tensor.size(dim) % mp == 0, f"{name} dim {dim} ({tensor.size(dim)}) not divisible by {mp}"
                    sh = tensor.size(dim) // mp
                    for i in range(mp):
                        state_dicts[i][name] = tensor.narrow(dim, i * sh, sh).contiguous()

    if pending:
        raise RuntimeError(f"Incomplete quantised linears: {list(pending)[:5]}")

    os.makedirs(save_path, exist_ok=True)
    for i in trange(mp, desc="write shards"):
        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))

    for fn in ["tokenizer.json", "tokenizer_config.json"]:
        src = os.path.join(hf_ckpt_path, fn)
        dst = os.path.join(save_path, fn)
        if os.path.exists(src):
            shutil.copyfile(src, dst)


if __name__ == "__main__":
    p = ArgumentParser()
    p.add_argument("--hf-ckpt-path", required=True)
    p.add_argument("--save-path", required=True)
    p.add_argument("--n-experts", type=int, required=True)
    p.add_argument("--model-parallel", type=int, required=True)
    a = p.parse_args()
    assert a.n_experts % a.model_parallel == 0
    main(a.hf_ckpt_path, a.save_path, a.n_experts, a.model_parallel)