# Sat3DGen / source / training_utils.py
# Repository page header: qian43 - "Upload 115 files" - commit 874cec4 (verified)
import os
import shutil
def save_checkpoint(accelerator, args, global_step, logger):
    """Save a training checkpoint and enforce ``args.checkpoints_total_limit``.

    The state is written to ``<args.output_dir>/checkpoint-<global_step>`` via
    ``accelerator.save_state``. If ``args.checkpoints_total_limit`` is not
    ``None``, the oldest ``checkpoint-<step>`` directories (lowest step first)
    are then removed so that at most that many remain, counting the one that
    was just written.

    Args:
        accelerator: Object exposing ``save_state(path)`` (presumably an
            Accelerate ``Accelerator`` -- confirm against the caller).
        args: Namespace providing ``output_dir`` and
            ``checkpoints_total_limit`` (an int, or ``None`` for no limit).
        global_step: Step number used to name the checkpoint directory.
        logger: Object exposing ``info(message)``.

    Removal failures are logged and skipped, never raised, so a stale
    directory cannot interrupt training.
    """
    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    accelerator.save_state(save_path)
    logger.info(f"Saved state to {save_path}")

    limit = args.checkpoints_total_limit
    if limit is None:
        return

    # Only consider entries named exactly "checkpoint-<integer>". Anything
    # else in the output directory (e.g. "checkpoint-best", "checkpoints")
    # is left alone instead of crashing the integer parse used for sorting.
    step_by_name = {}
    for name in os.listdir(args.output_dir):
        prefix, _, suffix = name.partition("-")
        if prefix == "checkpoint" and suffix.isdecimal():
            step_by_name[name] = int(suffix)
    checkpoints = sorted(step_by_name, key=step_by_name.get)

    # The new checkpoint already exists on disk at this point, so only the
    # surplus beyond the limit is removed. (Removing one extra here would
    # delete the checkpoint that was just saved when the limit is 1.)
    num_to_remove = len(checkpoints) - limit
    if num_to_remove <= 0:
        return

    removing_checkpoints = checkpoints[:num_to_remove]
    logger.info(
        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
    )
    logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

    for removing_checkpoint in removing_checkpoints:
        checkpoint_path = os.path.join(args.output_dir, removing_checkpoint)
        if not os.path.exists(checkpoint_path):
            logger.info(f'Directory "{checkpoint_path}" does not exist.')
            continue
        try:
            shutil.rmtree(checkpoint_path)
        except Exception as e:
            # Deliberately best-effort: report the failure and keep going.
            logger.info(f'Error removing directory "{checkpoint_path}": {e}. Continuing with the next item.')
        else:
            logger.info(f'Directory "{checkpoint_path}" has been removed.')