File size: 1,650 Bytes
874cec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
import shutil


def _checkpoint_step(name):
    """Return the step encoded in a ``checkpoint-<step>`` name, or None if *name* is not of that form."""
    prefix, sep, step = name.partition("-")
    # isdecimal() guarantees int() below cannot raise.
    if prefix != "checkpoint" or not sep or not step.isdecimal():
        return None
    return int(step)


def save_checkpoint(accelerator, args, global_step, logger):
    """Save model checkpoint and enforce checkpoints_total_limit.

    The state is written to ``<args.output_dir>/checkpoint-<global_step>`` via
    ``accelerator.save_state``. If ``args.checkpoints_total_limit`` is not None,
    the oldest other ``checkpoint-<step>`` directories (ordered by step) are then
    deleted so that at most that many checkpoints remain, counting the one just
    saved. The just-saved checkpoint itself is never deleted, so a limit below 1
    behaves like a limit of 1.

    Args:
        accelerator: Object providing ``save_state(path)``.
        args: Namespace with ``output_dir`` and ``checkpoints_total_limit``
            (int or None).
        global_step: Training step used to name the new checkpoint directory.
        logger: Logger used to report progress and removal failures.
    """
    current_name = f"checkpoint-{global_step}"
    save_path = os.path.join(args.output_dir, current_name)
    accelerator.save_state(save_path)
    logger.info(f"Saved state to {save_path}")

    if args.checkpoints_total_limit is None:
        return

    # Collect older, well-formed checkpoint directories. Entries that merely
    # start with "checkpoint" (e.g. "checkpoints.txt") are ignored instead of
    # crashing the step parser, and the checkpoint just written is excluded so
    # it can never be pruned (e.g. after resuming from an earlier step).
    candidates = []
    for name in os.listdir(args.output_dir):
        if name == current_name:
            continue
        step = _checkpoint_step(name)
        if step is None or not os.path.isdir(os.path.join(args.output_dir, name)):
            continue
        candidates.append((step, name))
    candidates.sort()

    # The save above already happened, so the new checkpoint occupies one slot
    # of the limit; the remaining slots go to the most recent older checkpoints.
    max_others = max(args.checkpoints_total_limit, 1) - 1
    num_to_remove = len(candidates) - max_others
    if num_to_remove <= 0:
        return

    removing_checkpoints = [name for _, name in candidates[:num_to_remove]]
    total_checkpoints = len(candidates) + 1  # include the one just saved

    logger.info(
        f"{total_checkpoints} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
    )
    logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

    for removing_checkpoint in removing_checkpoints:
        checkpoint_path = os.path.join(args.output_dir, removing_checkpoint)
        # Re-check existence: another process may have removed it since listing.
        if os.path.exists(checkpoint_path):
            try:
                shutil.rmtree(checkpoint_path)
                logger.info(f'Directory "{checkpoint_path}" has been removed.')
            except Exception as e:
                # Best-effort cleanup: a failed removal must not abort training.
                logger.warning(f'Error removing directory "{checkpoint_path}": {e}. Continuing with the next item.')
        else:
            logger.info(f'Directory "{checkpoint_path}" does not exist.')