| import json |
| import logging |
| import os |
| import subprocess |
| from argparse import ArgumentParser |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def parse_args():
    """Build an ArgumentParser dynamically from the command line itself.

    SageMaker forwards hyperparameters as ``--name value`` (or
    ``--name=value``) flags that are not known ahead of time.  A first
    ``parse_known_args`` pass discovers every unknown option token, each
    one is registered as a plain string argument, and a second full parse
    produces the final namespace.

    Returns:
        argparse.Namespace: one string-valued attribute per flag seen on
        the command line.
    """
    parser = ArgumentParser()
    # First pass: collect tokens argparse does not recognize yet.
    _, unknown = parser.parse_known_args()
    seen = set()
    for arg in unknown:
        # Option tokens start with "-" ("--" tokens do too); bare value
        # tokens such as "0.01" are skipped and consumed in the second
        # pass as the preceding option's value.
        if arg.startswith("-"):
            option = arg.split("=")[0]
            # Registering the same option twice would make add_argument
            # raise argparse.ArgumentError, so deduplicate.
            if option not in seen:
                seen.add(option)
                parser.add_argument(option)
    return parser.parse_args()
|
|
|
|
def main():
    """SageMaker training entry point.

    Reads the SageMaker runtime environment (``SM_NUM_GPUS``,
    ``SM_HOSTS``, ``SM_CURRENT_HOST``) and launches ``./run_glue.py``
    under ``torch.distributed.launch``, forwarding every parsed
    hyperparameter as a ``--name value`` pair.  On multi-node jobs the
    first host in ``SM_HOSTS`` acts as the rendezvous master.

    Raises:
        subprocess.CalledProcessError: if the launched training process
            exits with a non-zero status, so the SageMaker job is marked
            as failed instead of silently succeeding.
    """
    args = parse_args()
    port = 8888
    num_gpus = int(os.environ["SM_NUM_GPUS"])
    hosts = json.loads(os.environ["SM_HOSTS"])
    num_nodes = len(hosts)
    current_host = os.environ["SM_CURRENT_HOST"]
    # This node's rank is its position in the (identically ordered on
    # every host) SM_HOSTS list.
    rank = hosts.index(current_host)
    os.environ["NCCL_DEBUG"] = "INFO"

    # Build the command as an argv list (shell=False): hyperparameter
    # values containing spaces or shell metacharacters are passed through
    # verbatim instead of being re-tokenized (or executed) by a shell.
    cmd = ["python", "-m", "torch.distributed.launch"]
    if num_nodes > 1:
        cmd += [
            f"--nnodes={num_nodes}",
            f"--node_rank={rank}",
            f"--master_addr={hosts[0]}",
            f"--master_port={port}",
        ]
    cmd.append(f"--nproc_per_node={num_gpus}")
    cmd.append("./run_glue.py")
    for parameter, value in vars(args).items():
        cmd += [f"--{parameter}", str(value)]

    try:
        # check=True: a failing training process must fail this launcher;
        # without it subprocess.run ignores the child's exit status.
        subprocess.run(cmd, check=True)
    except Exception:
        logger.exception("Distributed training launch failed")
        raise
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|