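"""Convert a NeMo Megatron-GPT checkpoint (.nemo) to a Hugging Face checkpoint.

The converted weights are loaded into an existing HF model/tokenizer pair and
saved together as a full HF model directory.
"""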
import os
from argparse import ArgumentParser
from collections import OrderedDict

import torch
from joblib import Parallel, delayed
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
    MegatronGPTModel,
)
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.utils import logging
from pytorch_lightning import Trainer
from transformers import AutoModelForCausalLM, AutoTokenizer


def convert_layer(model, layer: int, torch_dtype) -> dict:
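    """Convert one NeMo decoder layer into HF (Llama-style) parameter names.

    Keys are emitted with a transient ``model.module.`` prefix that is
    normalized later in ``convert`` when megatron_amp_O2 is set.
    """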
    converted = {}

    param_to_weights = lambda param: param.to(torch_dtype)

    hidden_size: int = model.cfg.hidden_size
    head_num: int = model.cfg.num_attention_heads
    ffn_hidden_size: int = model.cfg.ffn_hidden_size
    num_query_groups: int = model.cfg.get("num_query_groups", head_num)

    head_size: int = hidden_size // head_num
    heads_per_group: int = head_num // num_query_groups
    qkv_total_dim: int = head_num + 2 * num_query_groups

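    # NeMo fuses Q, K and V into a single linear_qkv weight. Reshaped to
    # [qkv_total_dim, head_size, hidden_size], each query group is laid out
    # as [q_0 ... q_{heads_per_group - 1}, k, v] along the first dimension.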
    qkv_weights = model.state_dict()[
        f"model.module.decoder.layers.{layer}.self_attention.linear_qkv.weight"
    ]
    qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size])

    q_slice = torch.cat(
        [
            torch.arange(
                (heads_per_group + 2) * i,
                (heads_per_group + 2) * i + heads_per_group,
            )
            for i in range(num_query_groups)
        ]
    )
    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
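    # Example: with head_num = 32 and num_query_groups = 8 (GQA),
    # heads_per_group = 4 and qkv_total_dim = 48, giving
    # q_slice = [0..3, 6..9, ...], k_slice = [4, 10, ...], v_slice = [5, 11, ...].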

    q_weights_base_name: str = f"model.module.layers.{layer}.self_attn.q_proj.weight"
    k_weights_base_name: str = f"model.module.layers.{layer}.self_attn.k_proj.weight"
    v_weights_base_name: str = f"model.module.layers.{layer}.self_attn.v_proj.weight"

    converted[q_weights_base_name] = param_to_weights(
        qkv_weights[q_slice].reshape(-1, hidden_size)
    )
    converted[k_weights_base_name] = param_to_weights(
        qkv_weights[k_slice].reshape(-1, hidden_size)
    )
    converted[v_weights_base_name] = param_to_weights(
        qkv_weights[v_slice].reshape(-1, hidden_size)
    )

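    # Attention output (dense) projection.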
    o_weight = model.state_dict()[
        f"model.module.decoder.layers.{layer}.self_attention.linear_proj.weight"
    ]
    o_weight_base_name = f"model.module.layers.{layer}.self_attn.o_proj.weight"
    converted[o_weight_base_name] = param_to_weights(o_weight)

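    # MLP: NeMo's linear_fc1 fuses gate_proj and up_proj for SwiGLU; the
    # first ffn_hidden_size rows are the gate, the remaining rows the up
    # projection.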
    mlp_weights = model.state_dict()[
        f"model.module.decoder.layers.{layer}.mlp.linear_fc1.weight"
    ]
    mlp_gate_proj_weight = mlp_weights[:ffn_hidden_size, :]
    mlp_up_proj_weight = mlp_weights[ffn_hidden_size:, :]

    mlp_gate_proj_base_name = f"model.module.layers.{layer}.mlp.gate_proj.weight"
    mlp_up_proj_base_name = f"model.module.layers.{layer}.mlp.up_proj.weight"

    converted[mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight)
    converted[mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight)

    mlp_down_proj_weight = model.state_dict()[
        f"model.module.decoder.layers.{layer}.mlp.linear_fc2.weight"
    ]
    mlp_down_proj_base_name = f"model.module.layers.{layer}.mlp.down_proj.weight"
    converted[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)

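    # LayerNorms: NeMo stores them fused with the following linear layer
    # (layer_norm_weight on linear_qkv and linear_fc1).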
    input_ln_weight = model.state_dict()[
        f"model.module.decoder.layers.{layer}.self_attention.linear_qkv.layer_norm_weight"
    ]
    input_ln_base_name = f"model.module.layers.{layer}.input_layernorm.weight"
    converted[input_ln_base_name] = param_to_weights(input_ln_weight)

    post_attn_ln_weight = model.state_dict()[
        f"model.module.decoder.layers.{layer}.mlp.linear_fc1.layer_norm_weight"
    ]
    post_attn_ln_base_name: str = (
        f"model.module.layers.{layer}.post_attention_layernorm.weight"
    )
    converted[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)

    logging.info(f"Layer {layer} converted")

    return converted


def convert(
    input_nemo_file: str,
    output_hf_file: str,
    precision=None,
    cpu_only: bool = False,
    n_jobs: int = 1,
) -> torch.dtype:
    """
    Convert NeMo weights to HF weights and return the torch dtype used.
    """
    dummy_trainer = Trainer(devices=1, accelerator="cpu", strategy=NLPDDPStrategy())
    model_config = MegatronGPTModel.restore_from(
        input_nemo_file, trainer=dummy_trainer, return_config=True
    )
    model_config.tensor_model_parallel_size = 1
    model_config.pipeline_model_parallel_size = 1
    model_config.sequence_parallel = False
    if cpu_only:
        logging.info(
            "******** Loading model on CPU. This will take a significant amount of time."
        )
        map_location = torch.device("cpu")
        model_config.use_cpu_initialization = True
    else:
        map_location = None

    model = MegatronGPTModel.restore_from(
        input_nemo_file,
        trainer=dummy_trainer,
        override_config_path=model_config,
        map_location=map_location,
    )

    if precision is None:
        precision = model.cfg.precision
    if precision in [32, "32"]:
        torch_dtype = torch.float32
    elif precision in [16, "16", "16-mixed"]:
        torch_dtype = torch.float16
    elif precision in ["bf16", "bf16-mixed"]:
        torch_dtype = torch.bfloat16
    else:
        logging.warning(
            f"Precision string {precision} is not recognized, falling back to fp32"
        )
        torch_dtype = torch.float32
    logging.info(f"Using precision {torch_dtype}")

    checkpoint = OrderedDict()

    param_to_weights = lambda param: param.to(torch_dtype)

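    # Token embeddings (this HF key has no "module." prefix, so it is
    # untouched by the megatron_amp_O2 renaming below).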
    embed_weight = model.state_dict()["model.module.embedding.word_embeddings.weight"]
    checkpoint["model.embed_tokens.weight"] = param_to_weights(embed_weight)

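    # Convert decoder layers in parallel; threads share the in-memory model.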
    num_layers: int = model.cfg.num_layers
    n_jobs = min(n_jobs, num_layers)
    converted_layers: list[dict] = Parallel(n_jobs=n_jobs, prefer="threads")(
        delayed(convert_layer)(model, layer, torch_dtype) for layer in range(num_layers)
    )
    for converted_layer in converted_layers:
        checkpoint.update(converted_layer)

    final_ln_weight = model.state_dict()["model.module.decoder.final_layernorm.weight"]
    checkpoint["model.module.norm.weight"] = param_to_weights(final_ln_weight)

    output_layer_weight = model.state_dict()["model.module.output_layer.weight"]
    checkpoint["lm_head.weight"] = param_to_weights(output_layer_weight)

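    # With megatron_amp_O2 the NeMo state dict (and hence the keys emitted
    # above) carries a "model.module." prefix; strip "module." so the keys
    # match the HF state dict layout.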
    if model_config.get("megatron_amp_O2", False):
        keys: list[str] = list(checkpoint.keys())
        for key in keys:
            checkpoint[key.replace("model.module.", "model.", 1)] = checkpoint.pop(key)

    os.makedirs(os.path.dirname(output_hf_file), exist_ok=True)
    torch.save(checkpoint, output_hf_file)
    logging.info(f"Weights saved to {output_hf_file}")

    return torch_dtype


def replace_hf_weights_and_tokenizer(
    weights_file: str,
    torch_dtype: torch.dtype,
    input_hf_path: str,
    output_hf_path: str,
) -> None:
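    """Load converted weights into a base HF model and save model + tokenizer."""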
    orig_hf_model = AutoModelForCausalLM.from_pretrained(
        input_hf_path, torch_dtype=torch_dtype
    )
    nemo_weights = torch.load(weights_file)

    tokenizer = AutoTokenizer.from_pretrained(input_hf_path)
    tokenizer_length: int = len(tokenizer)
    if tokenizer_length > orig_hf_model.config.vocab_size:
        logging.warning(
            f"Tokenizer length {tokenizer_length} exceeds model vocab size "
            f"{orig_hf_model.config.vocab_size}. Resizing token embeddings."
        )
        orig_hf_model.resize_token_embeddings(tokenizer_length)

    orig_hf_model.load_state_dict(nemo_weights)
    orig_hf_model.save_pretrained(output_hf_path)
    logging.info(f"Full HF model saved to {output_hf_path}")

    tokenizer.save_pretrained(output_hf_path)
    logging.info(f"Tokenizer saved to {output_hf_path}")


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--input-name-or-path",
        type=str,
        default=None,
        required=True,
        help="Path to .nemo file or extracted folder",
    )
    parser.add_argument(
        "--input-hf-path",
        type=str,
        default="llm-jp/llm-jp-13b-v2.0",
        help="Path to the HF model directory to use as input",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default=None,
        required=True,
        help="Output HF model directory",
    )
    parser.add_argument(
        "--model-id",
        type=str,
        default=None,
        required=True,
        help="Model ID for the output HF model",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default=None,
        help="Precision of output weights. Defaults to the precision of the input NeMo weights (model.cfg.precision)",
    )
    parser.add_argument(
        "--cpu-only",
        action="store_true",
        help="Load the model on CPU only. Useful if the model cannot fit in GPU memory, but makes the conversion significantly slower.",
    )
    parser.add_argument(
        "--n-jobs",
        type=int,
        default=1,
        help="Number of parallel jobs for converting model layers. Capped at the number of layers.",
    )
    args = parser.parse_args()

    output_hf_path: str = f"{args.output_path}/{args.model_id}"
    torch_dtype = convert(
        input_nemo_file=args.input_name_or_path,
        output_hf_file=f"{output_hf_path}/pytorch_model.bin",
        precision=args.precision,
        cpu_only=args.cpu_only,
        n_jobs=args.n_jobs,
    )
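    # The intermediate pytorch_model.bin is only a staging file; it is
    # superseded by save_pretrained below and removed afterwards.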
    replace_hf_weights_and_tokenizer(
        weights_file=f"{output_hf_path}/pytorch_model.bin",
        torch_dtype=torch_dtype,
        input_hf_path=args.input_hf_path,
        output_hf_path=output_hf_path,
    )
    os.remove(f"{output_hf_path}/pytorch_model.bin")


if __name__ == "__main__":
    main()