| import os
|
| import sys
|
| import time
|
| import __main__
|
| import traceback
|
| import mlflow
|
|
|
| import colbert.utils.distributed as distributed
|
|
|
| from contextlib import contextmanager
|
| from colbert.utils.logging import Logger
|
| from colbert.utils.utils import timestamp, create_directory, print_message
|
|
|
|
|
class _RunManager():
    """Bookkeeping for a single experiment run.

    Holds the run's output location (``<root>/<experiment>/<script>/<name>``),
    its exit status, and, once :meth:`init` has been called, a set of logging
    shortcuts forwarded from the run's ``Logger``.
    """

    def __init__(self):
        # Location attributes stay None until init() is called.
        self.experiments_root = None
        self.experiment = None
        self.path = None

        self.script = self._get_script_name()
        self.name = self._generate_default_run_name()
        # Remember the auto-generated name; init() may replace `name`.
        self.original_name = self.name
        self.exit_status = 'FINISHED'

        self._logger = None
        self.start_time = time.time()

    def init(self, rank, root, experiment, name):
        """Set up the run directory and attach the logger.

        Args:
            rank: Rank of the calling process. Only a process with
                ``rank < 1`` prompts the user and creates the directory.
            root: Root directory for all experiments (made absolute).
            experiment: Experiment name; must not contain ``/``.
            name: Run name; must not contain ``/``.

        Raises:
            AssertionError: If ``experiment`` or ``name`` contains ``/``, or
                if the run directory already exists and the user does not
                answer ``yes`` to the overwrite prompt.
        """
        # Explicit raises rather than `assert` statements so the checks are
        # still enforced under `python -O`. AssertionError is kept so callers
        # observe the same exception type as before.
        if '/' in experiment:
            raise AssertionError(experiment)
        if '/' in name:
            raise AssertionError(name)

        self.experiments_root = os.path.abspath(root)
        self.experiment = experiment
        self.name = name
        self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name)

        if rank < 1:
            if os.path.exists(self.path):
                print('\n\n')
                print_message("It seems that ", self.path, " already exists.")
                print_message("Do you want to overwrite it? \t yes/no \n")

                response = input()

                # The path is re-checked after the prompt: the user may have
                # removed the directory by hand before answering.
                if response.strip() != 'yes' and os.path.exists(self.path):
                    raise AssertionError(self.path)

            # Covers both a fresh run and a directory that disappeared while
            # the prompt was waiting for an answer.
            if not os.path.exists(self.path):
                create_directory(self.path)

        distributed.barrier(rank)

        self._logger = Logger(rank, self)

        # Expose the logger's methods directly on the run manager.
        self._log_args = self._logger._log_args
        self.warn = self._logger.warn
        self.info = self._logger.info
        self.info_all = self._logger.info_all
        self.log_metric = self._logger.log_metric
        self.log_new_artifact = self._logger.log_new_artifact

    def _generate_default_run_name(self):
        """Return the default run name, taken from ``timestamp()``."""
        return timestamp()

    def _get_script_name(self):
        """Return the basename of the ``__main__`` script, or ``'none'``.

        ``'none'`` is returned when ``__main__`` has no usable ``__file__``
        attribute.
        """
        main_file = getattr(__main__, '__file__', None)
        if main_file is None:
            return 'none'

        return os.path.basename(main_file)

    @contextmanager
    def context(self, consider_failed_if_interrupted=True):
        """Wrap the body of a run, recording its outcome.

        Relies on the logger attributes set by :meth:`init`, so ``init`` must
        have been called first.

        On ``KeyboardInterrupt`` the exception and artifacts are logged and
        the process exits with status ``130``; the run is marked ``KILLED``
        only if ``consider_failed_if_interrupted`` is true. Any other
        ``Exception`` is logged, marks the run ``FAILED`` and is re-raised.
        In every case the elapsed time and run names are written as
        artifacts and the mlflow run is ended with the final exit status.
        """
        try:
            yield

        except KeyboardInterrupt as ex:
            print('\n\nInterrupted\n\n')
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            if consider_failed_if_interrupted:
                self.exit_status = 'KILLED'

            # 128 + 2: the conventional shell exit code for SIGINT.
            sys.exit(128 + 2)

        except Exception as ex:
            self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
            self._logger._log_all_artifacts()

            self.exit_status = 'FAILED'

            # Bare `raise` keeps the original traceback without adding a frame.
            raise

        finally:
            total_seconds = str(time.time() - self.start_time) + '\n'
            original_name = str(self.original_name)
            name = str(self.name)

            logs_path = self._logger.logs_path
            self.log_new_artifact(os.path.join(logs_path, 'elapsed.txt'), total_seconds)
            self.log_new_artifact(os.path.join(logs_path, 'name.original.txt'), original_name)
            self.log_new_artifact(os.path.join(logs_path, 'name.txt'), name)

            self._logger._log_all_artifacts()

            mlflow.end_run(status=self.exit_status)
|
|
|
|
|
# Module-level run manager instance, created once when this module is imported.
Run = _RunManager()
|
|
|