| import os |
| import shutil |
| from watchdog.observers import Observer |
| from watchdog.events import FileSystemEventHandler |
| import time |
| import re |
|
|
class CheckpointHandler(FileSystemEventHandler):
    """Prune old ``global_step_<N>`` checkpoint directories in one folder.

    The newest ``max_checkpoints`` checkpoint directories (ordered by
    ``os.path.getctime``) are always kept.  Older ones are deleted unless
    their name belongs to the protected set.

    Args:
        folder_path: Directory that directly contains the checkpoint folders.
        max_checkpoints: How many of the most recent checkpoints to keep.
            ``0`` means every unprotected checkpoint is removed.
        protected_steps: Iterable of step numbers whose ``global_step_<N>``
            directories must never be deleted.  Defaults to
            ``DEFAULT_PROTECTED_STEPS``.

    Raises:
        ValueError: If ``max_checkpoints`` is negative.
    """

    # Steps preserved by default, regardless of age.
    DEFAULT_PROTECTED_STEPS = (45, 90, 135, 180, 220)

    # Matched against the whole directory name so look-alikes such as
    # ``global_step_10.tmp`` are never treated as deletable checkpoints.
    _CHECKPOINT_NAME = re.compile(r"global_step_\d+")

    def __init__(self, folder_path, max_checkpoints=2, protected_steps=None):
        super().__init__()
        if max_checkpoints < 0:
            raise ValueError(
                f"max_checkpoints must be >= 0, got {max_checkpoints}"
            )
        self.folder_path = folder_path
        self.max_checkpoints = max_checkpoints
        if protected_steps is None:
            protected_steps = self.DEFAULT_PROTECTED_STEPS
        self.protected_names = frozenset(
            f"global_step_{step}" for step in protected_steps
        )

    def on_created(self, event):
        """Handle a filesystem "created" event.

        File events are ignored.  No action is taken for new directories
        here either; pruning only happens when ``cleanup_checkpoints`` is
        called.
        """
        if not event.is_directory:
            return

    def cleanup_checkpoints(self):
        """Delete surplus checkpoint directories, oldest first.

        Directories are ordered by ``os.path.getctime``; on platforms where
        that value is the metadata-change time, a recently touched directory
        sorts as newer.  A directory that cannot be inspected or removed is
        reported and skipped so one failure does not abort the whole pass.
        """
        dated = []
        for name in os.listdir(self.folder_path):
            if not self._CHECKPOINT_NAME.fullmatch(name):
                continue
            path = os.path.join(self.folder_path, name)
            if not os.path.isdir(path):
                continue
            try:
                dated.append((os.path.getctime(path), path))
            except OSError:
                # The directory vanished (or became unreadable) between the
                # listing and the stat call; there is nothing to prune.
                continue
        dated.sort()

        # Explicit count instead of ``[:-max_checkpoints]``: a negative-zero
        # slice (``[:-0]``) is empty, which silently disabled pruning when
        # ``max_checkpoints`` was 0.
        surplus = len(dated) - self.max_checkpoints
        if surplus <= 0:
            print(f"No need to remove any checkpoints, {len(dated)} checkpoints exist")
            return

        for _, path in dated[:surplus]:
            if os.path.basename(path) in self.protected_names:
                print(f"Skipped specific checkpoint: {path}")
                continue
            try:
                shutil.rmtree(path)
            except OSError as exc:
                print(f"Failed to remove checkpoint {path}: {exc}")
                continue
            print(f"Removed old checkpoint: {path}")
|
|
# Folder watched when ``main`` is called without arguments.
DEFAULT_CHECKPOINT_DIR = (
    '/data/wuxinrui/easyr1_checkpoints/'
    '1_5B_TCMv2_long_short_regular_budget_modified'
)


def main(folder_path=DEFAULT_CHECKPOINT_DIR, max_checkpoints=2, interval=300):
    """Watch a checkpoint folder and prune it periodically until interrupted.

    Args:
        folder_path: Directory containing the checkpoint folders.
        max_checkpoints: Number of most recent checkpoints to keep.
        interval: Seconds to sleep between two cleanup passes.

    The loop runs until ``KeyboardInterrupt`` (Ctrl-C), which ends the
    function normally.  Any other exception propagates to the caller, but
    the observer is stopped and joined first in every case.
    """
    event_handler = CheckpointHandler(folder_path, max_checkpoints)
    observer = Observer()
    observer.schedule(event_handler, folder_path, recursive=False)
    observer.start()

    try:
        while True:
            event_handler.cleanup_checkpoints()
            time.sleep(interval)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to end the watcher.
        pass
    finally:
        # Always shut the observer thread down, not only on Ctrl-C.
        observer.stop()
        observer.join()


if __name__ == "__main__":
    main()