import logging
import math
import os
import random
import re
from datetime import timedelta
from typing import Optional

import hydra

import numpy as np
import omegaconf
import torch
import torch.distributed as dist
from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf


def multiply_all(*args):
    """Multiply all arguments together and return the product as a Python scalar."""
    return np.prod(np.array(args)).item()
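
# Illustrative usage (a sketch, not part of the original module):
#
#   multiply_all(2, 3, 4)  # -> 24; this also backs the "times" resolver below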


def collect_dict_keys(config):
    """Recursively iterate through a dataset configuration and collect all the
    `dict_key` entries that are defined."""
    val_keys = []

    if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
        # A collate_fn node defines the dict_key for its dataset.
        val_keys.append(config["dict_key"])
    else:
        # Recurse into nested DictConfig and ListConfig nodes.
        for v in config.values():
            if isinstance(v, type(config)):
                val_keys.extend(collect_dict_keys(v))
            elif isinstance(v, omegaconf.listconfig.ListConfig):
                for item in v:
                    if isinstance(item, type(config)):
                        val_keys.extend(collect_dict_keys(item))
    return val_keys
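
# Illustrative sketch (hypothetical config; `my_pkg.collate_fn` is a made-up
# target used only to show what this helper matches on):
#
#   cfg = OmegaConf.create(
#       {"train": {"_target_": "my_pkg.collate_fn", "dict_key": "images"}}
#   )
#   collect_dict_keys(cfg)  # -> ["images"]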


class Phase:
    TRAIN = "train"
    VAL = "val"


def register_omegaconf_resolvers():
    OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
    OmegaConf.register_new_resolver("get_class", hydra.utils.get_class)
    OmegaConf.register_new_resolver("add", lambda x, y: x + y)
    OmegaConf.register_new_resolver("times", multiply_all)
    OmegaConf.register_new_resolver("divide", lambda x, y: x / y)
    OmegaConf.register_new_resolver("pow", lambda x, y: x**y)
    OmegaConf.register_new_resolver("subtract", lambda x, y: x - y)
    OmegaConf.register_new_resolver("range", lambda x: list(range(x)))
    OmegaConf.register_new_resolver("int", lambda x: int(x))
    OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x)))
    OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x))
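
# Illustrative sketch (not part of the original module): once registered, the
# resolvers let config files compute values inline, e.g.:
#
#   register_omegaconf_resolvers()
#   cfg = OmegaConf.create({"base": 4, "scaled": "${times:${base},2,8}"})
#   cfg.scaled  # -> 64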


def setup_distributed_backend(backend, timeout_mins):
    """
    Initialize torch.distributed and set the CUDA device.
    Expects environment variables to be set as per
    https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
    along with the environment variable "LOCAL_RANK" which is used to set the CUDA device.
    """
    # Enable async error handling so a failed NCCL op tears down the process
    # instead of hanging.
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
    dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins))
    return dist.get_rank()
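
# Illustrative sketch (not part of the original module): under `torchrun`, the
# required variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, LOCAL_RANK)
# are set automatically, so setup reduces to:
#
#   rank = setup_distributed_backend("nccl", timeout_mins=30)
#   local_rank, _ = get_machine_local_and_dist_rank()
#   torch.cuda.set_device(local_rank)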


def get_machine_local_and_dist_rank():
    """
    Get the distributed and local rank of the current gpu.
    """
    # Read the raw strings first: calling int() on a missing (None) variable
    # would raise a TypeError before the assert below could fire.
    local_rank = os.environ.get("LOCAL_RANK", None)
    distributed_rank = os.environ.get("RANK", None)
    assert (
        local_rank is not None and distributed_rank is not None
    ), "Please set the RANK and LOCAL_RANK environment variables."
    return int(local_rank), int(distributed_rank)


def print_cfg(cfg):
    """
    Supports printing both a Hydra DictConfig and an AttrDict config.
    """
    logging.info("Training with config:")
    logging.info(OmegaConf.to_yaml(cfg))


def set_seeds(seed_value, max_epochs, dist_rank):
    """
    Set the python random, numpy and torch seeds for each gpu. Also set the
    CUDA seeds if CUDA is available. This helps make training deterministic.
    """
    # Derive a distinct seed per rank so randomness differs across gpus.
    seed_value = (seed_value + dist_rank) * max_epochs
    logging.info(f"MACHINE SEED: {seed_value}")
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)
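
# Illustrative usage (a sketch, not part of the original module):
#
#   _, rank = get_machine_local_and_dist_rank()
#   set_seeds(seed_value=123, max_epochs=10, dist_rank=rank)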


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except Exception:
        logging.error(f"Error creating directory: {dir_path}")
    return is_success


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_amp_type(amp_type: Optional[str] = None):
    if amp_type is None:
        return None
    assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
    if amp_type == "bfloat16":
        return torch.bfloat16
    else:
        return torch.float16
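
# Illustrative sketch (not part of the original module; `model` and `batch`
# are placeholders): the resolved dtype plugs directly into torch.autocast:
#
#   amp_dtype = get_amp_type("bfloat16")  # -> torch.bfloat16
#   with torch.autocast(device_type="cuda", dtype=amp_dtype):
#       loss = model(batch)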


def log_env_variables():
    env_keys = sorted(list(os.environ.keys()))
    st = ""
    for k in env_keys:
        v = os.environ[k]
        st += f"{k}={v}\n"
    logging.info("Logging ENV_VARIABLES")
    logging.info(st)


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self._allow_updates = True

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
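
# Illustrative usage (a sketch, not part of the original module):
#
#   loss_meter = AverageMeter("Loss", device="cuda", fmt=":.4f")
#   loss_meter.update(0.75, n=32)  # batch mean of 0.75 over 32 samples
#   loss_meter.update(0.65, n=32)
#   str(loss_meter)  # -> "Loss: 0.6500 (0.7000)"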


class MemMeter:
    """Computes and stores the current, average, and max peak memory usage per iteration"""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.peak = 0
        self.sum = 0
        self.count = 0
        self._allow_updates = True

    def update(self, n=1, reset_peak_usage=True):
        # Peak bytes allocated since the last reset, converted to GB.
        self.val = torch.cuda.max_memory_allocated() // 1e9
        self.sum += self.val * n
        self.count += n
        self.avg = self.sum / self.count
        self.peak = max(self.peak, self.val)
        if reset_peak_usage:
            torch.cuda.reset_peak_memory_stats()

    def __str__(self):
        fmtstr = (
            "{name}: {val"
            + self.fmt
            + "} ({avg"
            + self.fmt
            + "}/{peak"
            + self.fmt
            + "})"
        )
        return fmtstr.format(**self.__dict__)


def human_readable_time(time_seconds):
    time = int(time_seconds)
    minutes, seconds = divmod(time, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    return f"{days:02}d {hours:02}h {minutes:02}m"


class DurationMeter:
    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.device = device
        self.fmt = fmt
        self.val = 0

    def reset(self):
        self.val = 0

    def update(self, val):
        self.val = val

    def add(self, val):
        self.val += val

    def __str__(self):
        return f"{self.name}: {human_readable_time(self.val)}"


class ProgressMeter:
    def __init__(self, num_batches, meters, real_meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.real_meters = real_meters
        self.prefix = prefix

    def display(self, batch, enable_print=False):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        entries += [
            " | ".join(
                [
                    f"{os.path.join(name, subname)}: {val:.4f}"
                    for subname, val in meter.compute().items()
                ]
            )
            for name, meter in self.real_meters.items()
        ]
        logging.info(" | ".join(entries))
        if enable_print:
            print(" | ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def get_resume_checkpoint(checkpoint_save_dir):
    """Return the path to `checkpoint.pt` in the save dir, or None if it is absent."""
    if not g_pathmgr.isdir(checkpoint_save_dir):
        return None
    ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt")
    if not g_pathmgr.isfile(ckpt_file):
        return None
    return ckpt_file
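
# Illustrative usage (a sketch, not part of the original module; the directory
# path is a placeholder):
#
#   ckpt_path = get_resume_checkpoint("/path/to/exp/checkpoints")
#   if ckpt_path is not None:
#       state = torch.load(ckpt_path, map_location="cpu")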