| import sys |
|
|
| sys.dont_write_bytecode = True |
|
|
| import os |
| import re |
| import glob |
| import time |
| import torch |
| import argparse |
| import subprocess |
| import torch.multiprocessing as mp |
|
|
| from core.config import Config |
| from core import Trainer |
|
|
def main(rank, config):
    """Create a ``Trainer`` for one process and run its training loop.

    Args:
        rank: Index of this process. The launcher passes ``0`` for a
            single-process run; ``mp.spawn`` supplies the process index
            when several workers are started.
        config: Configuration dictionary forwarded unchanged to ``Trainer``.
    """
    Trainer(rank, config).train_loop()
|
|
# One CSV row of `nvidia-smi --query-gpu=index,memory.used,memory.total,
# utilization.gpu --format=csv,noheader,nounits`: four integers.
_GPU_ROW_PATTERN = re.compile(r'(\d+),\s*(\d+),\s*(\d+),\s*(\d+)')


def resolve_config_path(config_name, search_root='./config'):
    """Locate exactly one config file named *config_name* under *search_root*.

    Args:
        config_name: File name of the config, with or without the ``.yaml``
            suffix (the suffix is appended when missing).
        search_root: Directory that is searched recursively.

    Returns:
        The path of the single matching file, as returned by ``glob``.

    Raises:
        FileNotFoundError: No file with that name exists under *search_root*.
        RuntimeError: More than one file matches, so the choice is ambiguous.
    """
    if not config_name.endswith('.yaml'):
        config_name = config_name + '.yaml'
    matches = glob.glob(f'{search_root}/**/{config_name}', recursive=True)
    if not matches:
        raise FileNotFoundError(
            f"Config file '{config_name}' not found under '{search_root}'"
        )
    if len(matches) > 1:
        raise RuntimeError(f"Config files conflict: {matches}")
    return matches[0]


def rank_gpus_by_load(smi_output):
    """Order GPU indices from least to most loaded.

    Args:
        smi_output: Text produced by ``nvidia-smi`` with the query
            ``index,memory.used,memory.total,utilization.gpu`` in
            ``csv,noheader,nounits`` format (one GPU per line).

    Returns:
        A list of integer device indices sorted by ascending load score,
        where the score is ``utilization.gpu`` plus the percentage of memory
        in use. Lines that do not start with four comma-separated integers
        are ignored, so the list may be empty.

    Raises:
        ZeroDivisionError: A row reports a total memory of zero.
    """
    scored = []
    for row in smi_output.strip().split('\n'):
        match = _GPU_ROW_PATTERN.match(row)
        if match is None:
            continue
        device_id, mem_used, mem_total, gpu_util = map(int, match.groups())
        # Compute load and memory pressure are weighted equally (0-100 each).
        score = gpu_util + (mem_used / mem_total) * 100
        scored.append((device_id, score))
    scored.sort(key=lambda entry: entry[1])
    return [device_id for device_id, _ in scored]


def query_least_utilized_gpus(n_gpu):
    """Ask ``nvidia-smi`` for the *n_gpu* least loaded devices.

    Args:
        n_gpu: Maximum number of device ids to return.

    Returns:
        Up to *n_gpu* device indices as strings, least loaded first.

    Raises:
        RuntimeError: ``nvidia-smi`` exited with a non-zero status or its
            output contained no parseable GPU rows.
        OSError: ``nvidia-smi`` could not be executed.
    """
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=index,memory.used,memory.total,utilization.gpu', '--format=csv,noheader,nounits'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True
    )
    if result.returncode != 0:
        raise RuntimeError(f"nvidia-smi error: {result.stderr}")

    ranked = rank_gpus_by_load(result.stdout)
    if not ranked:
        # An empty selection would only fail later and more obscurely.
        raise RuntimeError("nvidia-smi reported no usable GPUs")
    return [str(device_id) for device_id in ranked[:n_gpu]]


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default=None, help='Name of config file')
    parser.add_argument('--seed', type=int, default=-1, help='Seed')
    parser.add_argument('--device', type=int, default=-1, help='Device')
    args = parser.parse_args()

    # Load the requested config, or the default one when none is given.
    if args.config:
        config = Config(resolve_config_path(args.config)).get_config_dict()
    else:
        config = Config("./config/InfLoRA.yaml").get_config_dict()

    # 'auto' means: pick the n_gpu least loaded devices reported by nvidia-smi.
    if config['device_ids'] == 'auto':
        try:
            config["device_ids"] = query_least_utilized_gpus(config["n_gpu"])
        except Exception as e:
            # Deliberate best-effort: any failure while probing the GPUs
            # falls back to the first n_gpu devices. This must be a real
            # list, otherwise the isinstance check below would wrap the
            # whole range object in a one-element list.
            config["device_ids"] = list(range(config["n_gpu"]))
            print(f"Error while querying GPUs: {e}, using default device {config['device_ids']}")

    # Command-line overrides take precedence over the config file.
    if args.seed > -1:
        print(f'Seed : {config["seed"]} -> {args.seed}')
        config['seed'] = args.seed

    if args.device > -1:
        config['device_ids'] = args.device

    # Normalise a single device id to a one-element list.
    if not isinstance(config['device_ids'], list):
        config['device_ids'] = [config['device_ids']]

    print(f'Selected GPUs: {config["device_ids"]}')

    if config["n_gpu"] > 1:
        mp.spawn(main, nprocs=config["n_gpu"], args=(config,))
        # os.environ only accepts strings; assigning the list directly raised
        # TypeError once training had finished.
        # NOTE(review): this runs after mp.spawn has returned, so it cannot
        # influence the workers. Moving it before the spawn would change the
        # device numbering the workers see - confirm against Trainer first.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            str(device_id) for device_id in config["device_ids"]
        )
    else:
        main(0, config)
|
|