| import json |
| import os |
| import sys |
|
|
| from tap import Tap |
|
|
| import dist |
|
|
|
|
class Args(Tap):
    """Command-line configuration for pre-training, parsed with ``tap``.

    The first groups of fields are user-facing CLI options. The fields in the
    "derived / runtime state" group are filled in by ``init_dist_and_get_args``
    or updated while training runs, and are not normally given on the command
    line.
    """

    # ---- experiment / paths ----
    exp_name: str = 'mamba'  # experiment name, written to the log header
    exp_dir: str = ''  # output directory; its last component is sanitized at startup
    data_path: str = ''  # presumably the dataset location (consumed elsewhere)
    init_weight: str = ''  # presumably weights to initialize from (consumed elsewhere)
    resume_from: str = ''  # presumably a checkpoint to resume from (consumed elsewhere)

    # ---- masking ----
    mask: float = 0.75  # presumably the mask ratio

    # ---- model ----
    model: str = 'mambamim'  # model name; inspected by is_convnext / is_resnet
    input_size: int = 96
    sbn: bool = True  # forced to False when distributed is not initialized

    # ---- data loading ----
    bs: int = 1  # global batch size, split across processes at startup
    dataloader_workers: int = 8

    # ---- optimization ----
    dp: float = 0.0  # presumably the drop-path rate
    base_lr: float = 1e-4  # copied into `lr` at startup
    wd: float = 0.04
    wde: float = 0.2  # 0 means "fall back to wd"
    ep: int = 100
    wp_ep: int = 40  # presumably warm-up epochs
    clip: float = 5.  # presumably the gradient-clipping threshold
    opt: str = 'adamw'  # optimizer name, lower-cased at startup
    ada: float = 0.  # 0 means "use 0.999"

    # ---- derived / runtime state (not normally set on the command line) ----
    lr: float = 1e-4  # set from base_lr at startup
    batch_size_per_gpu: int = 0  # computed from bs and the world size at startup
    glb_batch_size: int = 0  # batch_size_per_gpu * world size
    densify_norm: str = ''  # overwritten with 'ln' at startup
    device: str = 'gpu'  # replaced by dist.get_device() at startup
    local_rank: int = 0
    cmd: str = ' '.join(sys.argv[1:])  # raw command-line arguments, for the log header
    # Git metadata, captured once when this class body executes; '[unknown]' is
    # used when git prints nothing (e.g. outside a repository).
    commit_id: str = os.popen('git rev-parse HEAD').read().strip() or '[unknown]'
    # Last line of `git log -1` output.
    commit_msg: str = (os.popen('git log -1').read().strip().splitlines() or ['[unknown]'])[-1].strip()
    last_loss: float = 0.
    cur_ep: str = ''
    remain_time: str = ''
    finish_time: str = ''
    first_logging: bool = True  # next log_epoch call truncates the log and writes a header
    # NOTE: a literal placeholder, not an f-string; init_dist_and_get_args
    # replaces it with <exp_dir>/pretrain_log.txt before logging starts.
    log_txt_name: str = '{args.exp_dir}/pretrain_log.txt'
    tb_lg_dir: str = ''  # TensorBoard dir; empty means <exp_dir>/tensorboard_log

    @property
    def is_convnext(self):
        """True when the model name contains 'convnext' or 'cnx'."""
        return 'convnext' in self.model or 'cnx' in self.model

    @property
    def is_resnet(self):
        """True when the model name contains 'resnet'."""
        return 'resnet' in self.model

    def log_epoch(self):
        """Append the current progress record to ``log_txt_name`` as a JSON line.

        Only the local-master process writes. On the first call (while
        ``first_logging`` is True) the file is truncated and a JSON header with
        the experiment name, command, git info and model is written first.
        """
        if not dist.is_local_master():
            return

        if self.first_logging:
            self.first_logging = False
            # 'w' truncates any log left over from a previous run in this dir.
            with open(self.log_txt_name, 'w') as fp:
                json.dump({
                    'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg,
                    'model': self.model,
                }, fp)
                fp.write('\n\n')

        with open(self.log_txt_name, 'a') as fp:
            json.dump({
                'cur_ep': self.cur_ep,
                'last_L': self.last_loss,
                'rema': self.remain_time, 'fini': self.finish_time,
            }, fp)
            fp.write('\n')
|
|
|
|
def init_dist_and_get_args():
    """Parse CLI arguments, prepare output dirs, init distributed, fill derived fields.

    Steps:
      1. Parse ``Args`` from the command line (``explicit_bool=True``).
      2. Sanitize the last component of ``exp_dir``: characters other than
         alphanumerics and '-' become '_'. An empty ``exp_dir`` resolves to the
         current working directory.
      3. Create ``exp_dir`` and the TensorBoard directory, and set
         ``log_txt_name``.
      4. Initialize the distributed environment via ``misc``.
      5. Derive device, batch sizes and optimizer-related defaults.

    Returns:
        Args: the parsed arguments with derived fields populated.
    """
    from utils import misc

    args = Args(explicit_bool=True).parse_args()

    # Only the final path component is sanitized; parent dirs are kept as-is.
    exp_abs = os.path.abspath(args.exp_dir)
    parent, leaf = os.path.dirname(exp_abs), os.path.basename(exp_abs)
    leaf = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in leaf)
    args.exp_dir = os.path.join(parent, leaf)

    os.makedirs(args.exp_dir, exist_ok=True)
    args.log_txt_name = os.path.join(args.exp_dir, 'pretrain_log.txt')
    args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log')
    try:
        os.makedirs(args.tb_lg_dir, exist_ok=True)
    except OSError as err:
        # Deliberately non-fatal, but no longer silent (and no longer
        # swallowing KeyboardInterrupt/SystemExit like a bare except did).
        print(f'[init_dist_and_get_args] warning: could not create {args.tb_lg_dir!r}: {err}', file=sys.stderr)

    misc.init_distributed_environ(exp_dir=args.exp_dir)

    # SyncBN is meaningless without an initialized process group.
    if not dist.initialized():
        args.sbn = False
    args.first_logging = True
    args.device = dist.get_device()

    world_size = dist.get_world_size()
    # `bs` is the global batch. Never let the per-process share round down to 0
    # (e.g. the default bs=1 on more than one process), which would leave both
    # batch sizes at 0.
    args.batch_size_per_gpu = max(1, args.bs // world_size)
    args.glb_batch_size = args.batch_size_per_gpu * world_size

    args.ada = args.ada or 0.999  # 0 (the CLI default) means "use 0.999"
    args.densify_norm = 'ln'  # always overrides any CLI value

    args.opt = args.opt.lower()
    args.lr = args.base_lr
    args.wde = args.wde or args.wd  # 0 means "use wd"

    return args
|
|