# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
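
"""Convert a Megatron NeMo GPT checkpoint (.nemo) into a Hugging Face
causal-LM checkpoint with Llama-style key names, reusing the config and
tokenizer of an existing HF model.

Example invocation (the script name and paths are illustrative):

    python convert_nemo_to_hf.py \
        --input-name-or-path /path/to/model.nemo \
        --input-hf-path llm-jp/llm-jp-13b-v2.0 \
        --output-path /path/to/output \
        --model-id my-model-v1.0 \
        --cpu-only \
        --n-jobs 4
"""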

import os
from argparse import ArgumentParser
from collections import OrderedDict

import torch
from joblib import Parallel, delayed
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
    MegatronGPTModel,
)
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.utils import logging
from pytorch_lightning import Trainer
from transformers import AutoModelForCausalLM, AutoTokenizer


def convert_layer(model, layer: int, torch_dtype) -> dict:
    """Convert one transformer layer from NeMo to HF key names, cast to `torch_dtype`."""
    converted = {}

    param_to_weights = lambda param: param.to(torch_dtype)  # noqa

    hidden_size: int = model.cfg.hidden_size
    head_num: int = model.cfg.num_attention_heads
    ffn_hidden_size: int = model.cfg.ffn_hidden_size
    num_query_groups: int = model.cfg.get("num_query_groups", head_num)

    head_size: int = hidden_size // head_num
    heads_per_group: int = head_num // num_query_groups
    qkv_total_dim: int = head_num + 2 * num_query_groups

    qkv_weights = model.state_dict()[
        f"model.module.decoder.layers.{layer}.self_attention.linear_qkv.weight"
    ]
    qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size])

    q_slice = torch.cat(
        [
            torch.arange(
                (heads_per_group + 2) * i,
                (heads_per_group + 2) * i + heads_per_group,
            )
            for i in range(num_query_groups)
        ]
    )
    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
    # Example of slices
    # 7b: num_query_groups = head_num = 32,
    # q_slice = [0, 3, 6, 9 , ... 90, 93]
    # k_slice = [1, 4, 7, 10, ... 91, 94]
    # v_slice = [2, 5, 8, 11, ... 92, 95]
    # 70b (with GQA): num_query_groups = 8, head_num = 64
    # q_slice = [0, 1, .. 6, 7, 10, 11, .. 16, 17, 20, 21, .. 67, 70, ... 76, 77]
    # k_slice = [8, 18, 28, ... 68, 78]
    # v_slice = [9, 19, 29, ... 69, 79]
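
    # Sanity check: the q/k/v slices must together cover every row of the
    # fused QKV matrix exactly once, in both the MHA and GQA layouts above.
    assert sorted(torch.cat((q_slice, k_slice, v_slice)).tolist()) == list(
        range(qkv_total_dim)
    )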

    q_weights_base_name: str = f"model.module.layers.{layer}.self_attn.q_proj.weight"
    k_weights_base_name: str = f"model.module.layers.{layer}.self_attn.k_proj.weight"
    v_weights_base_name: str = f"model.module.layers.{layer}.self_attn.v_proj.weight"

    converted[q_weights_base_name] = param_to_weights(
        qkv_weights[q_slice].reshape(-1, hidden_size)
    )
    converted[k_weights_base_name] = param_to_weights(
        qkv_weights[k_slice].reshape(-1, hidden_size)
    )
    converted[v_weights_base_name] = param_to_weights(
        qkv_weights[v_slice].reshape(-1, hidden_size)
    )

    # attention dense
    o_weight = model.state_dict()[
        f"model.module.decoder.layers.{layer}.self_attention.linear_proj.weight"
    ]
    o_weight_base_name = f"model.module.layers.{layer}.self_attn.o_proj.weight"
    converted[o_weight_base_name] = param_to_weights(o_weight)

    # mlp: NeMo fuses gate_proj and up_proj into linear_fc1 (gate rows first)
    mlp_weights = model.state_dict()[
        f"model.module.decoder.layers.{layer}.mlp.linear_fc1.weight"
    ]
    mlp_gate_proj_weight = mlp_weights[:ffn_hidden_size, :]
    mlp_up_proj_weight = mlp_weights[ffn_hidden_size:, :]

    mlp_gate_proj_base_name = f"model.module.layers.{layer}.mlp.gate_proj.weight"
    mlp_up_proj_base_name = f"model.module.layers.{layer}.mlp.up_proj.weight"

    converted[mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight)
    converted[mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight)

    mlp_down_proj_weight = model.state_dict()[
        f"model.module.decoder.layers.{layer}.mlp.linear_fc2.weight"
    ]
    mlp_down_proj_base_name = f"model.module.layers.{layer}.mlp.down_proj.weight"
    converted[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)

    # layernorm
    input_ln_weight = model.state_dict()[
        f"model.module.decoder.layers.{layer}.self_attention.linear_qkv.layer_norm_weight"
    ]
    input_ln_base_name = f"model.module.layers.{layer}.input_layernorm.weight"
    converted[input_ln_base_name] = param_to_weights(input_ln_weight)

    post_attn_ln_weight = model.state_dict()[
        f"model.module.decoder.layers.{layer}.mlp.linear_fc1.layer_norm_weight"
    ]
    post_attn_ln_base_name: str = (
        f"model.module.layers.{layer}.post_attention_layernorm.weight"
    )
    converted[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)

    logging.info(f"Layer {layer} converted")

    return converted


def convert(
    input_nemo_file: str,
    output_hf_file: str,
    precision=None,
    cpu_only=False,
    n_jobs=1,
) -> torch.dtype:
    """
    Convert NeMo weights to HF weights, save them to `output_hf_file`,
    and return the torch dtype the weights were cast to.
    """
    dummy_trainer = Trainer(devices=1, accelerator="cpu", strategy=NLPDDPStrategy())
    model_config = MegatronGPTModel.restore_from(
        input_nemo_file, trainer=dummy_trainer, return_config=True
    )
    model_config.tensor_model_parallel_size = 1
    model_config.pipeline_model_parallel_size = 1
    model_config.sequence_parallel = False
    if cpu_only:
        logging.info(
            "******** Loading model on CPU. This will take a significant amount of time."
        )
        map_location = torch.device("cpu")
        model_config.use_cpu_initialization = True
    else:
        map_location = None

    model = MegatronGPTModel.restore_from(
        input_nemo_file,
        trainer=dummy_trainer,
        override_config_path=model_config,
        map_location=map_location,
    )
    if precision is None:
        precision = model.cfg.precision
    if precision in [32, "32"]:
        torch_dtype = torch.float32
    elif precision in [16, "16", "16-mixed"]:
        torch_dtype = torch.float16
    elif precision in ["bf16", "bf16-mixed"]:
        torch_dtype = torch.bfloat16
    else:
        logging.warning(
            f"Precision string {precision} is not recognized, falling back to fp32"
        )
        torch_dtype = torch.float32  # fallback
    logging.info(f"Using precision {torch_dtype}")

    checkpoint = OrderedDict()

    param_to_weights = lambda param: param.to(torch_dtype)  # noqa

    # Embedding (state-dict keys assume a megatron_amp_O2 checkpoint, whose
    # NeMo weights carry a "model.module." prefix)
    embed_weight = model.state_dict()["model.module.embedding.word_embeddings.weight"]
    checkpoint["model.embed_tokens.weight"] = param_to_weights(embed_weight)

    # Layers
    num_layers: int = model.cfg.num_layers
    n_jobs = min(n_jobs, num_layers)
    converted_layers: list[dict] = Parallel(n_jobs=n_jobs, prefer="threads")(
        delayed(convert_layer)(model, layer, torch_dtype) for layer in range(num_layers)
    )
    for converted_layer in converted_layers:
        checkpoint.update(converted_layer)

    final_ln_weight = model.state_dict()["model.module.decoder.final_layernorm.weight"]
    checkpoint["model.module.norm.weight"] = param_to_weights(final_ln_weight)
    output_layer_weight = model.state_dict()["model.module.output_layer.weight"]
    checkpoint["lm_head.weight"] = param_to_weights(output_layer_weight)

    # megatron_amp_O2 wraps the model so its keys carry a "module." prefix;
    # strip it here to match the standard HF key layout.
    if model_config.get("megatron_amp_O2", False):
        keys: list[str] = list(checkpoint.keys())
        for key in keys:
            checkpoint[key.replace("model.module.", "model.", 1)] = checkpoint.pop(key)

    os.makedirs(os.path.dirname(output_hf_file), exist_ok=True)
    torch.save(checkpoint, output_hf_file)
    logging.info(f"Weights saved to {output_hf_file}")

    return torch_dtype


def replace_hf_weights_and_tokenizer(
    weights_file: str,
    torch_dtype: torch.dtype,
    input_hf_path: str,
    output_hf_path: str,
) -> None:
    orig_hf_model = AutoModelForCausalLM.from_pretrained(
        input_hf_path, torch_dtype=torch_dtype
    )
    nemo_weights = torch.load(weights_file)

    tokenizer = AutoTokenizer.from_pretrained(input_hf_path)
    tokenizer_length: int = len(tokenizer)
    if tokenizer_length > orig_hf_model.config.vocab_size:
        logging.warning(
            f"Tokenizer length {tokenizer_length} exceeds model vocab size "
            f"{orig_hf_model.config.vocab_size}. Resizing token embeddings."
        )
        orig_hf_model.resize_token_embeddings(tokenizer_length)

    orig_hf_model.load_state_dict(nemo_weights)
    orig_hf_model.save_pretrained(output_hf_path)
    logging.info(f"Full HF model saved to {output_hf_path}")

    tokenizer.save_pretrained(output_hf_path)
    logging.info(f"Tokenizer saved to {output_hf_path}")


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--input-name-or-path",
        type=str,
        default=None,
        required=True,
        help="Path to .nemo file or extracted folder",
    )
    parser.add_argument(
        "--input-hf-path",
        type=str,
        default="llm-jp/llm-jp-13b-v2.0",
        help="Path to the HF model directory to use as input",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default=None,
        required=True,
        help="Output HF model directory",
    )
    parser.add_argument(
        "--model-id",
        type=str,
        default=None,
        required=True,
        help="Model ID for the Output HF model",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default=None,
        help="Precision of output weights. Defaults to precision of the input nemo weights (model.cfg.trainer.precision)",
    )
    parser.add_argument(
        "--cpu-only",
        action="store_true",
        help="Load model in cpu only. Useful if the model cannot fit in GPU memory, but this option makes the conversion script significantly slower.",
    )
    parser.add_argument(
        "--n-jobs",
        type=int,
        default=1,
        help="Number of parallel jobs to run. Default is 1. This is for converting the model layers in parallel. If n_jobs > num_layers, it will be set to num_layers.",
    )
    args = parser.parse_args()

    output_hf_path: str = f"{args.output_path}/{args.model_id}"
    torch_dtype = convert(
        input_nemo_file=args.input_name_or_path,
        output_hf_file=f"{output_hf_path}/pytorch_model.bin",
        precision=args.precision,
        cpu_only=args.cpu_only,
        n_jobs=args.n_jobs,
    )
    replace_hf_weights_and_tokenizer(
        weights_file=f"{output_hf_path}/pytorch_model.bin",
        torch_dtype=torch_dtype,
        input_hf_path=args.input_hf_path,
        output_hf_path=output_hf_path,
    )
    os.remove(f"{output_hf_path}/pytorch_model.bin")


if __name__ == "__main__":
    main()
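

# A quick smoke test for the converted checkpoint (illustrative; the output
# path below is hypothetical):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     model = AutoModelForCausalLM.from_pretrained("/path/to/output/my-model-v1.0")
#     tokenizer = AutoTokenizer.from_pretrained("/path/to/output/my-model-v1.0")
#     inputs = tokenizer("Hello", return_tensors="pt")
#     print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))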