'''
Converts a transformers model to a format compatible with flexgen.
'''

import argparse
import os
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args()
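
# Example invocation (a sketch; the script filename and the model folder are
# illustrative assumptions, not taken from this file):
#
#   python convert-to-flexgen.py models/opt-1.3b
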
def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    global torch_linear_init_backup
    global torch_layer_norm_init_backup

    # from_pretrained() overwrites every parameter with checkpoint values anyway,
    # so the default random initialization is wasted work: back up the original
    # initializers and replace them with no-ops while the model is created.
    torch_linear_init_backup = torch.nn.Linear.reset_parameters
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)

    torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

def restore_torch_init():
    """Roll back the changes made by disable_torch_init."""
    setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
    setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
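
# The disable/restore pair could also be wrapped in a context manager so the
# rollback runs even if loading fails (a sketch, not part of the original script):
#
#   from contextlib import contextmanager
#
#   @contextmanager
#   def no_torch_init():
#       disable_torch_init()
#       try:
#           yield
#       finally:
#           restore_torch_init()
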
if __name__ == '__main__':
    path = Path(args.MODEL)
    model_name = path.name

    print(f"Loading {model_name}...")
    # Skip the redundant default initialization while the checkpoint is loaded
    # in fp16, the precision that flexgen works with.
    disable_torch_init()
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    restore_torch_init()

    tokenizer = AutoTokenizer.from_pretrained(path)

    out_folder = Path(f"models/{model_name}-np")
    if not out_folder.exists():
        os.mkdir(out_folder)

| print(f"Saving the converted model to {out_folder}...") |
| for name, param in tqdm(list(model.model.named_parameters())): |
| name = name.replace("decoder.final_layer_norm", "decoder.layer_norm") |
| param_path = os.path.join(out_folder, name) |
| with open(param_path, "wb") as f: |
| np.save(f, param.cpu().detach().numpy()) |
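
# Quick round-trip check (a sketch; the folder and parameter name below are
# assumptions based on OPT-style models, not guaranteed for every architecture):
#
#   import numpy as np
#   w = np.load("models/opt-1.3b-np/decoder.embed_tokens.weight")
#   print(w.shape, w.dtype)  # expect the original tensor shape and float16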