| import os |
| import shutil |
|
|
|
|
def _list_step_checkpoints(output_dir):
    """Return the ``checkpoint-<N>`` directory names in *output_dir*, smallest N first.

    Entries that merely start with "checkpoint" but have no integer step suffix
    (e.g. ``checkpoints``, ``checkpoint-best``, ``checkpoint_notes.txt``) or that
    are not directories are ignored. They would otherwise crash the numeric sort
    or count toward the retention limit.
    """
    prefix = "checkpoint-"
    found = []
    for name in os.listdir(output_dir):
        if not name.startswith(prefix):
            continue
        step = name[len(prefix):]
        # isdecimal() only admits characters int() can parse.
        if not step.isdecimal():
            continue
        if not os.path.isdir(os.path.join(output_dir, name)):
            continue
        found.append((int(step), name))
    return [name for _, name in sorted(found)]


def save_checkpoint(accelerator, args, global_step, logger):
    """Save model checkpoint and enforce checkpoints_total_limit.

    The state is written to ``<args.output_dir>/checkpoint-<global_step>`` via
    ``accelerator.save_state``. If ``args.checkpoints_total_limit`` is not
    ``None``, the ``checkpoint-<N>`` directories in ``args.output_dir`` with the
    smallest N are deleted, so that at most ``checkpoints_total_limit`` remain.
    The checkpoint written by this call is never deleted.

    Args:
        accelerator: Object exposing ``save_state(path)``.
        args: Namespace providing ``output_dir`` and ``checkpoints_total_limit``.
        global_step: Step number used to name the checkpoint directory.
        logger: Logger receiving progress and error messages.

    Deletion failures are logged and skipped; they do not raise.
    """
    current_name = f"checkpoint-{global_step}"
    save_path = os.path.join(args.output_dir, current_name)
    accelerator.save_state(save_path)
    logger.info(f"Saved state to {save_path}")

    limit = args.checkpoints_total_limit
    if limit is None:
        return

    checkpoints = _list_step_checkpoints(args.output_dir)
    # Pruning happens *after* the save, so the new checkpoint is already in
    # `checkpoints`; trim down to exactly `limit` entries.
    num_to_remove = len(checkpoints) - limit
    if num_to_remove <= 0:
        return

    # Never prune the checkpoint just written. Its step may not be the largest,
    # e.g. when training resumed from an earlier checkpoint.
    removable = [name for name in checkpoints if name != current_name]
    removing_checkpoints = removable[:num_to_remove]
    if not removing_checkpoints:
        return

    logger.info(
        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
    )
    logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

    for removing_checkpoint in removing_checkpoints:
        checkpoint_path = os.path.join(args.output_dir, removing_checkpoint)
        try:
            shutil.rmtree(checkpoint_path)
        except FileNotFoundError:
            # Vanished between listing and deletion (e.g. removed by another process).
            logger.info(f'Directory "{checkpoint_path}" does not exist.')
        except OSError as e:
            # Best-effort cleanup: a failed deletion must not abort training.
            logger.warning(f'Error removing directory "{checkpoint_path}": {e}. Continuing with the next item.')
        else:
            logger.info(f'Directory "{checkpoint_path}" has been removed.')
|
|