| import argparse
|
| import os
|
| import sys
|
| from abc import ABC
|
| from typing import Type
|
|
|
|
|
class DefaultConfigs(ABC):
    """Baseline option set, stored as plain class-level attributes.

    Assigning ``exp_name`` derives ``exp_dir``, ``ckpt_dir`` and ``logs_path``
    from ``exp_root`` and creates the experiment and checkpoint directories.
    ``ckpt_path`` is not touched by that setter.
    """

    # General run / data / loader values (consumed by code outside this file).
    gpus = [0]
    seed = 3407
    arch = "resnet50"
    datasets = ["zhaolian_train"]
    datasets_test = ["adm_res_abs_ddim20s"]
    mode = "binary"
    class_bal = False
    batch_size = 64
    loadSize = 256
    cropSize = 224
    epoch = "latest"
    num_workers = 20
    serial_batches = False
    isTrain = True

    # Resize / blur / JPEG / gray / aug_* values; the names suggest data
    # augmentation, but their consumers are not visible here.
    rz_interp = ["bilinear"]
    blur_prob = 0.1
    blur_sig = [0.5]
    jpg_prob = 0.1
    jpg_method = ["cv2"]
    jpg_qual = [75]
    gray_prob = 0.0
    aug_resize = True
    aug_crop = True
    aug_flip = True
    aug_norm = True

    # Warm-up, early-stop, optimiser and save-frequency values (presumably
    # read by the training loop; verify against callers).
    warmup = False
    warmup_epoch = 3
    earlystop = True
    earlystop_epoch = 5
    optim = "adam"
    new_optim = False
    loss_freq = 400
    save_latest_freq = 2000
    save_epoch_freq = 20
    continue_train = False
    epoch_count = 1
    last_epoch = -1
    nepoch = 400
    beta1 = 0.9
    lr = 0.0001
    init_type = "normal"
    init_gain = 0.02
    pretrained = True

    # Paths: root_dir is the parent of the directory holding this file.
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    dataset_root = os.path.join(root_dir, "data")
    exp_root = os.path.join(dataset_root, "exp")
    # Placeholders; the first four are filled in by the ``exp_name`` setter.
    _exp_name = ""
    exp_dir = ""
    ckpt_dir = ""
    logs_path = ""
    ckpt_path = ""

    @property
    def exp_name(self):
        """Return the experiment name last assigned (``""`` by default)."""
        return self._exp_name

    @exp_name.setter
    def exp_name(self, value: str):
        """Store the name, derive the dependent paths and create the directories."""
        self._exp_name = value

        experiment_dir = os.path.join(self.exp_root, value)
        checkpoint_dir = os.path.join(experiment_dir, "ckpt")

        self.exp_dir = experiment_dir
        self.ckpt_dir = checkpoint_dir
        self.logs_path = os.path.join(experiment_dir, "logs.txt")

        # Parent first, then the checkpoint sub-directory.
        for directory in (experiment_dir, checkpoint_dir):
            os.makedirs(directory, exist_ok=True)

    def to_dict(self):
        """Return every public, non-callable attribute as ``{name: value}``.

        Names come from ``dir(self)``, so class-level defaults and the
        ``exp_name`` property are included; underscore-prefixed names and
        methods are skipped.
        """
        exported = {}
        for name in dir(self):
            attr = getattr(self, name)
            if name.startswith("_") or callable(attr):
                continue
            exported[name] = attr
        return exported
|
|
|
|
|
def args_list2dict(arg_list: list):
    """Turn a flat ``[key1, value1, key2, value2, ...]`` list into a dict.

    Args:
        arg_list: Alternating keys and values.

    Returns:
        A dict mapping each even-indexed item to the item that follows it.
        If a key is repeated, the last value wins.

    Raises:
        ValueError: If ``arg_list`` has an odd number of items.
    """
    if len(arg_list) % 2 != 0:
        # Explicit raise rather than ``assert`` so the check is not stripped
        # under ``python -O`` (which would silently drop the dangling item).
        raise ValueError(f"Override list has odd length: {arg_list}; it must be a list of pairs")
    return dict(zip(arg_list[::2], arg_list[1::2]))
|
|
|
|
|
def str2bool(v: str) -> bool:
    """Interpret a command-line style token as a boolean.

    Args:
        v: Usually a string. A ``bool`` is returned unchanged, and any other
            non-string value is converted with plain truthiness.

    Returns:
        ``True`` for "true", "yes", "on", "y", "t", "1" and ``False`` for
        "false", "no", "off", "n", "f", "0". Matching ignores case and
        surrounding whitespace. Any other string falls back to ``bool(v)``,
        i.e. ``True`` unless the string is empty.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Non-string scalars (e.g. 0 / 1) have no ``.lower()``; use truthiness.
        return bool(v)

    # Normalise once so padded tokens such as " false " are still recognised
    # instead of falling through to the truthiness fallback.
    token = v.strip().lower()
    if token in ("true", "yes", "on", "y", "t", "1"):
        return True
    if token in ("false", "no", "off", "n", "f", "0"):
        return False
    # Unrecognised token: keep the historical fallback on the raw value.
    return bool(v)
|
|
|
|
|
def str2list(v: str, element_type=None) -> list:
    """Parse a comma-separated string (optionally wrapped in ``[...]``) into a list.

    Args:
        v: The text to split. A ``list``, ``tuple`` or ``set`` is accepted as
            already split and is not re-parsed.
        element_type: Optional callable applied to every item (e.g. ``int``).

    Returns:
        A list of whitespace-stripped items, converted with ``element_type``
        when given. Empty input (``""`` or ``"[]"``) yields ``[]``. An
        already-split ``v`` with no ``element_type`` is returned as the same
        object.

    Raises:
        Whatever ``element_type`` raises for an item it cannot convert.
    """
    if not isinstance(v, (list, tuple, set)):
        # Strip outer whitespace first so " [1, 2] " still loses its brackets.
        inner = v.strip().lstrip("[").rstrip("]").strip()
        # Nothing between the brackets means an empty list, not [""], which
        # would break conversions such as int("").
        v = [item.strip() for item in inner.split(",")] if inner else []
    if element_type is not None:
        v = [element_type(item) for item in v]
    return v
|
|
|
|
|
# Alias for "a DefaultConfigs subclass"; used only to annotate ``cfg`` below.
CONFIGCLASS = Type[DefaultConfigs]

# NOTE: the command line is parsed at import time, so importing this module
# consumes ``sys.argv``.
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", default=[0], type=int, nargs="+")
parser.add_argument("--exp_name", default="", type=str)
parser.add_argument("--ckpt", default="model_epoch_latest.pth", type=str)
# Everything after the known flags is collected verbatim as ``key value`` pairs.
parser.add_argument("opts", default=[], nargs=argparse.REMAINDER)
args = parser.parse_args()

# Prefer an experiment-local ``config.py`` under exp_root/<exp_name> when one
# exists; otherwise fall back to the defaults defined in this file.
if os.path.exists(os.path.join(DefaultConfigs.exp_root, args.exp_name, "config.py")):
    sys.path.insert(0, os.path.join(DefaultConfigs.exp_root, args.exp_name))
    from config import cfg

    # NOTE(review): annotated as a class type, while the fallback branch binds
    # an instance - confirm what the experiment config actually exports.
    cfg: CONFIGCLASS
else:
    cfg = DefaultConfigs()

# Apply ``key value`` overrides, coercing each value to the type of the
# attribute it replaces.
if args.opts:
    opts = args_list2dict(args.opts)
    for k, v in opts.items():
        if not hasattr(cfg, k):
            raise ValueError(f"Unrecognized option: {k}")
        _current = getattr(cfg, k)
        original_type = type(_current)
        if original_type is bool:
            setattr(cfg, k, str2bool(v))
        elif original_type in (list, tuple, set):
            # Take the element type from the first existing item. Iterating
            # (rather than indexing) also works for sets, and an empty
            # sequence gives no type to copy, so items then stay strings.
            _element_type = type(next(iter(_current))) if _current else None
            if _element_type is bool:
                # bool("false") is True; parse tokens the same way the scalar
                # branch does.
                _element_type = str2bool
            setattr(cfg, k, str2list(v, _element_type))
        else:
            setattr(cfg, k, original_type(v))

# ``--gpus`` is applied after the overrides, so it always wins over a ``gpus``
# entry in ``opts``.
cfg.gpus = args.gpus
os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join(str(gpu) for gpu in cfg.gpus)
# For a DefaultConfigs instance this assignment goes through the property
# setter, which derives exp_dir / ckpt_dir / logs_path and creates directories.
cfg.exp_name = args.exp_name
cfg.ckpt_path = os.path.join(cfg.ckpt_dir, args.ckpt)

# Accept a single comma-separated string as shorthand for a list of datasets.
if isinstance(cfg.datasets, str):
    cfg.datasets = cfg.datasets.split(",")
|
|
|