| |
| |
|
|
| import argparse |
| import io |
| import os |
| import tempfile |
| from datetime import timedelta |
|
|
| import fla |
| import torch |
| import torch.serialization |
| from torch.distributed.checkpoint.format_utils import dcp_to_torch_save |
| from torchtitan.tools.logging import init_logger, logger |
| from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer |
|
|
| from fla.models import HamiltonForCausalLM as NewModelForCausalLM, HamiltonConfig as NewConfig |
| MODEL_TYPE = NewConfig.model_type |
|
|
| |
|
|
|
|
@torch.inference_mode()
def save_pretrained(
    path: str,
    step: int,
    config: str,
    tokenizer: str
):
    """Export a distributed checkpoint as a Hugging Face style model directory.

    The config and tokenizer are loaded and re-saved into ``path``. The
    distributed checkpoint at ``<path>/checkpoint/step-<step>`` is then
    converted to a single ``torch.save`` file inside a temporary directory,
    its ``'model'`` entry is loaded into a model freshly built from the
    config, and that model is saved into ``path``.

    Args:
        path: Directory that contains ``checkpoint/step-<step>`` and that
            receives the saved config, tokenizer and model.
        step: Step number selecting the ``step-<step>`` checkpoint folder.
        config: Identifier handed to ``AutoConfig.from_pretrained``.
        tokenizer: Identifier handed to ``AutoTokenizer.from_pretrained``.
    """
    # --- config ---------------------------------------------------------
    logger.info(f"Loading the config from {config}")
    # Make the custom model type resolvable through the Auto* config API.
    AutoConfig.register(MODEL_TYPE, NewConfig)
    model_config = AutoConfig.from_pretrained(config, trust_remote_code=True)
    logger.info(f"Saving the config to {path}")
    model_config.save_pretrained(path)

    # --- tokenizer ------------------------------------------------------
    logger.info(f"Loading the tokenizer from {tokenizer}")
    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
    logger.info(f"Saving the tokenizer to {path}")
    hf_tokenizer.save_pretrained(path)

    # --- weights --------------------------------------------------------
    dcp_dir = os.path.join(path, f'checkpoint/step-{step}')
    with tempfile.TemporaryDirectory() as scratch_dir:
        # The merged file only lives inside the temporary directory, so it
        # has to be read back before this ``with`` block exits.
        merged_file = os.path.join(scratch_dir, 'checkpoint.pt')
        logger.info(f"Saving the distributed checkpoint to {merged_file}")
        dcp_to_torch_save(dcp_dir, merged_file)

        logger.info(f"Initializing the model from config\n{model_config}")
        # Map the custom config class to its causal-LM implementation.
        AutoModelForCausalLM.register(NewConfig, NewModelForCausalLM)
        hf_model = AutoModelForCausalLM.from_config(model_config)
        logger.info(hf_model)

        logger.info("Loading state dict from the checkpoint")
        # Allow-list these types for torch.load's restricted unpickler.
        # NOTE(review): presumably the merged checkpoint stores such objects
        # next to the weights -- confirm against the training code.
        torch.serialization.add_safe_globals([timedelta, io.BytesIO])
        hf_model.load_state_dict(
            torch.load(merged_file, map_location='cpu')['model']
        )

        logger.info(f"Saving the model to {path}")
        hf_model.save_pretrained(path)
|
|
|
|
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the checkpoint conversion script.

    The summary sentence is passed as ``description``. The first positional
    parameter of ``argparse.ArgumentParser`` is ``prog``, so passing the
    sentence positionally would make it show up as the program name in the
    usage/help output instead of as the description.

    Returns:
        A parser with the required ``--path``, ``--step``, ``--config`` and
        ``--tokenizer`` options.
    """
    parser = argparse.ArgumentParser(
        description="Convert DCP format model weights to huggingface-style."
    )
    parser.add_argument(
        "--path", type=str, required=True,
        help="directory holding checkpoint/step-<STEP>; converted files are saved here",
    )
    parser.add_argument(
        "--step", type=int, required=True,
        help="step number of the checkpoint to convert",
    )
    parser.add_argument(
        "--config", type=str, required=True,
        help="model config to load (passed to AutoConfig.from_pretrained)",
    )
    parser.add_argument(
        "--tokenizer", type=str, required=True,
        help="tokenizer to load (passed to AutoTokenizer.from_pretrained)",
    )
    return parser


if __name__ == "__main__":
    # Set up logging first so that every later step is reported.
    init_logger()
    args = build_arg_parser().parse_args()
    save_pretrained(args.path, args.step, args.config, args.tokenizer)
|
|