| |
| |
| |
|
|
import os
import time
from copy import deepcopy
from typing import Optional

import torch.backends.cudnn
import torch.distributed
import torch.nn as nn

from src.efficientvit.apps.data_provider import DataProvider
from src.efficientvit.apps.trainer.run_config import RunConfig
from src.efficientvit.apps.utils import (dist_init, dump_config,
                                         get_dist_local_rank, get_dist_rank,
                                         get_dist_size, init_modules, is_master,
                                         load_config, partial_update_config,
                                         zero_last_gamma)
from src.efficientvit.models.utils import (build_kwargs_from_config,
                                           load_state_dict_from_file)
|
|
# Public API of this setup module.
__all__ = [
    "save_exp_config",
    "setup_dist_env",
    "setup_seed",
    "setup_exp_config",
    "setup_data_provider",
    "setup_run_config",
    "init_model",
]
|
|
|
|
def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
    """Dump the experiment config to ``path/name``.

    Only the master process writes; all other ranks return immediately so
    the file is not written concurrently by several workers.
    """
    if is_master():
        dump_config(exp_config, os.path.join(path, name))
|
|
|
|
def setup_dist_env(gpu: Optional[str] = None) -> None:
    """Initialize the distributed environment and pin this process to its GPU.

    Args:
        gpu: value for ``CUDA_VISIBLE_DEVICES`` (e.g. ``"0,1"``); ``None``
            leaves the environment variable untouched.

    Fix: the original annotation ``gpu: str or None = None`` evaluates to
    plain ``str`` at definition time (``str or None`` is just ``str``), which
    is misleading; ``Optional[str]`` expresses the intent correctly.
    """
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Initialize the process group only once; dist_init is a project helper.
    if not torch.distributed.is_initialized():
        dist_init()
    # Enable cudnn autotuning (fastest kernels for fixed input shapes).
    torch.backends.cudnn.benchmark = True
    # Bind this process to its node-local GPU.
    torch.cuda.set_device(get_dist_local_rank())
|
|
|
|
def setup_seed(manual_seed: int, resume: bool) -> None:
    """Seed the CPU and all-GPU torch RNGs, offset by distributed rank.

    When resuming, the wall clock replaces ``manual_seed`` so the resumed
    run does not replay the exact same random stream. Each rank adds its
    own rank so workers draw different random numbers.
    """
    base = int(time.time()) if resume else manual_seed
    seed = base + get_dist_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
|
|
|
def setup_exp_config(
    config_path: str, recursive=True, opt_args: Optional[dict] = None
) -> dict:
    """Load an experiment config, merging ancestor defaults and overrides.

    Args:
        config_path: path to the experiment config file.
        recursive: when True, walk up the directory tree from
            ``config_path`` and merge every ``default.<ext>`` found, most
            general (closest to the filesystem root) first.
        opt_args: optional override dict applied last (highest priority),
            e.g. parsed command-line options.

    Returns:
        The merged config dict (a deep copy; loaded files are not aliased).

    Raises:
        ValueError: if ``config_path`` is not an existing file.

    Fix: the original annotation ``opt_args: dict or None = None`` evaluates
    to plain ``dict`` (``dict or None`` is just ``dict``); ``Optional[dict]``
    states the contract correctly.
    """
    if not os.path.isfile(config_path):
        raise ValueError(config_path)

    # Collect paths from most-specific to most-general, then reverse so the
    # root-most default is loaded first and progressively overridden.
    fpaths = [config_path]
    if recursive:
        extension = os.path.splitext(config_path)[1]
        while os.path.dirname(config_path) != config_path:
            config_path = os.path.dirname(config_path)
            fpath = os.path.join(config_path, "default" + extension)
            if os.path.isfile(fpath):
                fpaths.append(fpath)
        fpaths = fpaths[::-1]

    default_config = load_config(fpaths[0])
    exp_config = deepcopy(default_config)
    for fpath in fpaths[1:]:
        partial_update_config(exp_config, load_config(fpath))
    # Explicit overrides win over any file-based value.
    if opt_args is not None:
        partial_update_config(exp_config, opt_args)

    return exp_config
|
|
|
|
def setup_data_provider(
    exp_config: dict,
    data_provider_classes: list[type[DataProvider]],
    is_distributed: bool = True,
) -> DataProvider:
    """Instantiate the data provider selected by ``exp_config``.

    Mutates ``exp_config["data_provider"]`` in place: fills in distributed
    replica/rank info, derives ``test_batch_size`` (2x base when unset or
    falsy) and mirrors ``base_batch_size`` into ``batch_size`` and
    ``train_batch_size``. The provider class is looked up by its ``name``
    attribute against ``dp_config["dataset"]``.
    """
    cfg = exp_config["data_provider"]
    if is_distributed:
        cfg["num_replicas"] = get_dist_size()
        cfg["rank"] = get_dist_rank()
    else:
        cfg["num_replicas"] = None
        cfg["rank"] = None
    # Falsy (missing/None/0) test batch size falls back to twice the base.
    if not cfg.get("test_batch_size", None):
        cfg["test_batch_size"] = cfg["base_batch_size"] * 2
    cfg["batch_size"] = cfg["train_batch_size"] = cfg["base_batch_size"]

    registry = {cls.name: cls for cls in data_provider_classes}
    provider_cls = registry[cfg["dataset"]]

    kwargs = build_kwargs_from_config(cfg, provider_cls)
    return provider_cls(**kwargs)
|
|
|
|
def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
    """Build the run config, scaling the base learning rate by world size.

    Writes the scaled value back into ``exp_config["run_config"]["init_lr"]``
    (linear LR scaling with the number of distributed replicas), then
    instantiates ``run_config_cls`` from that sub-dict.
    """
    run_cfg = exp_config["run_config"]
    run_cfg["init_lr"] = run_cfg["base_lr"] * get_dist_size()
    return run_config_cls(**run_cfg)
|
|
|
|
def init_model(
    network: nn.Module,
    init_from: Optional[str] = None,
    backbone_init_from: Optional[str] = None,
    rand_init="trunc_normal",
    last_gamma=None,
) -> None:
    """Initialize ``network`` weights, then optionally load a checkpoint.

    Args:
        network: model to initialize (modified in place).
        init_from: path to a full-model state dict; loaded when the file
            exists, taking priority over ``backbone_init_from``.
        backbone_init_from: path to a state dict for ``network.backbone``
            only; used when ``init_from`` is absent but this file exists.
        rand_init: init scheme name passed to ``init_modules``.
        last_gamma: when not None, forwarded to ``zero_last_gamma``
            (presumably zeroes the final norm gamma in residual blocks —
            semantics live in the project helper).

    Fix: the original annotations ``init_from: str or None = None`` and
    ``backbone_init_from: str or None = None`` evaluate to plain ``str``
    (``str or None`` is just ``str``); ``Optional[str]`` is the correct type.
    """
    # Random init first so any layers missing from a checkpoint still end up
    # initialized.
    init_modules(network, init_type=rand_init)

    if last_gamma is not None:
        zero_last_gamma(network, last_gamma)

    if init_from is not None and os.path.isfile(init_from):
        network.load_state_dict(load_state_dict_from_file(init_from))
        print(f"Loaded init from {init_from}")
    elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
        network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
        print(f"Loaded backbone init from {backbone_init_from}")
    else:
        print(f"Random init ({rand_init}) with last gamma {last_gamma}")
|
|