| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Example: python scripts/checkpoint_averaging/zarr_distributed_checkpoint_averaging.py \ |
| --name_prefix=<checkpoint name> \ |
| --checkpoint_dir=<folder containing checkpoints> \ |
| --steps <list of checkpoint steps to average, if not provided, it will average all the checkpoints> |
| |
| This will generate a single new directory inside <folder containing checkpoints>, named <checkpoint name>-averaged (or <checkpoint name>-<steps joined by '_'>-averaged when --steps is given). |
| """ |
|
|
| import argparse |
| import logging |
| import os |
| import shutil |
| import numpy as np |
| import zarr |
|
|
# Configure the root logger at import time so the logging.info() progress
# messages emitted by this script are shown (the default level would hide them).
logging.basicConfig(level=logging.INFO)
|
|
|
|
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--name_prefix',
        required=True,  # without it the output name cannot be built
        help='Name of the final checkpoint. Will append -averaged automatically.',
    )
    parser.add_argument(
        '--checkpoint_dir',
        required=True,  # without it there is nothing to list or to write into
        help='Folder containing all the distributed checkpoints.',
    )
    parser.add_argument(
        '--steps',
        nargs='+',
        type=int,
        help='List of checkpoint steps to average. If not specified, will average all.',
    )
    return parser.parse_args()


def _select_checkpoints(checkpoint_dir, steps, output_name):
    """Return the names of the sub-directories of *checkpoint_dir* to average.

    Entries are visited in sorted order so the result (and therefore which
    checkpoint supplies the copied non-weight items) is deterministic.
    """
    selected = []
    for ckpt_dir in sorted(os.listdir(checkpoint_dir)):
        logging.info("Processing %s", ckpt_dir)
        # NOTE(review): presumably excludes the "-last" checkpoint copies — confirm
        # that every such directory name really ends in '0-last'.
        if ckpt_dir.endswith('0-last'):
            continue
        # Never feed the output of a previous run back into the average.
        if ckpt_dir == output_name:
            continue
        # Plain files cannot be checkpoints; listing them would raise later.
        if not os.path.isdir(os.path.join(checkpoint_dir, ckpt_dir)):
            continue
        # any() also prevents a checkpoint from being selected twice when the
        # same step is passed more than once.
        if steps is None or any(f"-step={step}-" in ckpt_dir for step in steps):
            selected.append(ckpt_dir)
    return selected


def _accumulate_weights(checkpoint_dir, checkpoint_paths):
    """Sum every weight array across the checkpoints.

    Returns a ``(sums, chunk_info, copy_items)`` tuple: the element-wise sums
    keyed by array name, the chunk layout of each array as first seen, and
    the paths (taken from the first checkpoint only) that are copied verbatim
    instead of being averaged.
    """
    sums = {}
    chunk_info = {}
    seen_counts = {}
    copy_items = []

    for ix, path in enumerate(checkpoint_paths):
        full_path = os.path.join(checkpoint_dir, path)

        for item in os.listdir(full_path):
            item_path = os.path.join(full_path, item)

            # Plain files, '._extra_state' entries and 'optimizer.' entries are
            # not averaged; they are copied as-is from the first checkpoint.
            copied_verbatim = (
                not os.path.isdir(item_path)
                or item.endswith('._extra_state')
                or item.startswith('optimizer.')
            )
            if copied_verbatim:
                if ix == 0:
                    copy_items.append(item_path)
                continue

            array = zarr.open(item_path, mode='r')
            if item not in sums:
                logging.info(f"Initialized average weights dict with: {item}")
                values = array[:]
                # Fail before loading anything else rather than after summing.
                if str(values.dtype).startswith("int"):
                    raise ValueError("Int type not supported")
                sums[item] = values
                chunk_info[item] = array.chunks
                seen_counts[item] = 1
            else:
                logging.info(f"Updated average weights dict with weight: {item}")
                sums[item] = sums[item] + array[:]
                seen_counts[item] += 1

    # The sums are later divided by the number of checkpoints, which is only
    # an average if every checkpoint contributed to every array.
    expected = len(checkpoint_paths)
    for item, seen in seen_counts.items():
        if seen != expected:
            raise ValueError(
                f"Weight {item} was found in {seen} of {expected} selected checkpoints; "
                "cannot compute a correct average"
            )

    return sums, chunk_info, copy_items


def _is_bfloat16(dtype):
    """Return True if *dtype* equals numpy's registered 'bfloat16' dtype.

    numpy raises TypeError for the name 'bfloat16' unless some extension has
    registered it; in that case no array can carry that dtype.
    """
    try:
        return dtype == np.dtype('bfloat16')
    except TypeError:
        return False


def _save_weights(avg_weights, chunk_info, ckpt_name):
    """Write every averaged array as an uncompressed zarr array under *ckpt_name*."""
    for k, input_arr in avg_weights.items():
        logging.info(f"Saving {k} to {ckpt_name}")
        output_array = zarr.create(
            input_arr.shape,
            dtype=input_arr.dtype,
            store=os.path.join(ckpt_name, k),
            chunks=chunk_info[k],
            compressor=None,
            fill_value=None,
            write_empty_chunks=True,
        )
        if _is_bfloat16(input_arr.dtype):
            # Patch the array's private dtype and rewrite b'<V2' to b'bfloat16'
            # in the stored '.zarray' metadata.
            # NOTE(review): presumably zarr serializes bfloat16 as a 2-byte
            # void type that readers would not recognize — confirm.
            output_array._dtype = input_arr.dtype
            zarray = output_array.store['.zarray']
            output_array.store['.zarray'] = zarray.replace(b'<V2', b'bfloat16')
        output_array[:] = input_arr


def _copy_items(copy_items, ckpt_name):
    """Copy the non-averaged files and directories into *ckpt_name*."""
    # shutil.copy() to a missing destination would create a *file* with that
    # name, so make sure the output directory exists first.
    os.makedirs(ckpt_name, exist_ok=True)
    for item in copy_items:
        is_file = os.path.isfile(item)
        logging.info(f"Copying {'file' if is_file else 'directory'} {item} to {ckpt_name}")
        if is_file:
            shutil.copy(item, ckpt_name)
        else:
            shutil.copytree(item, os.path.join(ckpt_name, os.path.basename(item)), dirs_exist_ok=True)


def main():
    """Average the weights of several zarr-based distributed checkpoints.

    Reads ``--name_prefix``, ``--checkpoint_dir`` and optionally ``--steps``
    from the command line, sums each weight array over the selected
    checkpoint directories, divides by the number of checkpoints and writes
    the result to ``<checkpoint_dir>/<name_prefix>[-<steps>]-averaged``.
    Entries that are not averaged (plain files, ``*._extra_state`` and
    ``optimizer.*``) are copied from the first selected checkpoint.

    Raises:
        ValueError: if no checkpoint is selected, if a weight has an integer
            dtype, or if a weight is not present in every selected checkpoint.
    """
    args = _parse_args()

    if args.steps is not None:
        logging.info(f"Will average only steps {args.steps}")

    # The output name is computed up front so that a directory left by a
    # previous run can be excluded from the inputs.
    if args.steps is None:
        output_name = args.name_prefix + '-averaged'
    else:
        steps_combined = '_'.join(str(x) for x in args.steps)
        output_name = args.name_prefix + '-' + steps_combined + '-averaged'
    ckpt_name = os.path.join(args.checkpoint_dir, output_name)

    checkpoint_paths = _select_checkpoints(args.checkpoint_dir, args.steps, output_name)
    n = len(checkpoint_paths)
    if n == 0:
        raise ValueError(f"No checkpoints to average were found in {args.checkpoint_dir}")

    logging.info(f"Averaging {n} checkpoints ... {'at steps:' + str(args.steps) if args.steps is not None else ''}")

    avg_weights, chunk_info, copy_items = _accumulate_weights(args.checkpoint_dir, checkpoint_paths)

    # Turn the sums into means in place (values only; the key set is unchanged).
    for k in avg_weights:
        logging.info(f"Average weights dict key : {k}, dtype : {avg_weights[k].dtype}, shape : {avg_weights[k].shape}")
        avg_weights[k] = avg_weights[k] / n

    _save_weights(avg_weights, chunk_info, ckpt_name)
    _copy_items(copy_items, ckpt_name)

    logging.info(f"Averaged distributed checkpoint saved as : {ckpt_name}")
|
|
|
|
# Run the averaging only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
|
|