import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # uncomment for synchronous CUDA error reporting
import argparse
import json
import time
from datetime import datetime

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

import ssl
# Disable TLS certificate verification so remote dataset shards / pretrained
# weights can be fetched from endpoints with self-signed certificates.
ssl._create_default_https_context = ssl._create_unverified_context

import webdataset as wds
from accelerate import Accelerator
from accelerate.utils import set_seed

# `print` here is the rank-aware wrapper from dist_train_utils (rank 0 only
# unless called with all=True).
from dist_train_utils import print, get_world_size, get_rank, get_local_rank, barrier, reduce_sum, reduce_mean
from lr_utils import CosineAnnealingWarmupRestarts
from dataset.get_vqgan_wds import get_dataset
from data_generation.vqgan.load import load_model
parser = argparse.ArgumentParser()
# ddp
parser.add_argument('--local_rank', default=-1, type=int, help='local rank for distributed training')
parser.add_argument('--batch_size', type=int, default=16)
# Despite the name, --epoch is the nominal number of batches per webdataset
# "epoch" and the scheduler cycle length (see loader.with_epoch() below).
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--base_lr', type=float, default=1e-5)
parser.add_argument('--log_folder', type=str, default='f16')
args = parser.parse_args()
set_seed(42)
accelerator = Accelerator()
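# set_seed(42) fixes the RNG state on every process; Accelerator then owns
# device placement, DDP wrapping, and distributed state for everything passed
# to accelerator.prepare() below.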
log_folder = args.log_folder
log_path = os.path.join('/mnt/bn/roboicl-jirong/codebase/DeLVM/logs/', log_folder)
if get_rank() == 0:
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        # make the log dir writable for other jobs/users
        os.system('chmod -R 777 ' + log_path)
accelerator.wait_for_everyone()
timestamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
WRITER = None  # only rank 0 writes tensorboard events
if get_rank() == 0:
    WRITER = SummaryWriter(os.path.join(log_path, timestamp), max_queue=1000)
print('world_size: {}, rank: {}, local_rank: {}'.format(get_world_size(), get_rank(), get_local_rank()), all=True)
# Accelerator already initialized the process group and assigned the GPU, so
# no manual dist.init_process_group() / torch.cuda.set_device() is needed here.
device = accelerator.device
print('device:', device, all=True)
# Resume from the most recent f16 VQGAN checkpoint; the second argument (8192)
# matches the codebook size of the checkpoint being loaded.
vqmodel = load_model('/mnt/bn/roboicl-jirong/codebase/DeLVM/logs/f16_192_real_calvin_robot_datacomp_1e-5_disc_start_0_weight_0.2_acc_1/checkpoint_vq_epoch_9999.tar', 8192)
# Optional warm start: copy shape-matched parameters from a teacher VQ model
# (vq_model_t, which would need to be loaded above). Kept for reference:
# copied_param, init_param = [], []
# for (name, param), (name_s, param_s) in zip(vq_model_t.named_parameters(), vqmodel.named_parameters()):
#     if param.shape == param_s.shape:
#         param_s.data = param.clone().data
#         copied_param.append(param_s)
#     else:
#         init_param.append(name_s)
# print('params copied from teacher')
# print(init_param)
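# Convert any BatchNorm layers (e.g. in the VQGAN discriminator) to
# SyncBatchNorm so that batch statistics are synchronized across ranks.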
vqmodel = nn.SyncBatchNorm.convert_sync_batchnorm(vqmodel)
vqmodel = vqmodel.to(device)
# Single CALVIN shards and the real_data_img_* / real_data_hand_img_* shard
# sets were mixed in during earlier runs and can be re-added here with their
# own weights.
ds_list = []
ds3 = get_dataset('/mnt/bn/roboicllq-data1/calvin_img/calvin_img_{00000000..00000110}.tar', seed=42)
ds_list += [(ds3, 0.1)]
ds6 = get_dataset('/mnt/bn/roboicllq-data1/calvin_img/hands/calvin_hands_img_{00000000..00000110}.tar', seed=42)
ds_list += [(ds6, 0.02)]
with open('/mnt/bn/roboicl-jirong/codebase/RoboICL/robot_img.json', 'r') as f:
    tars = json.load(f)
ds4 = get_dataset(tars, seed=42)
ds_list += [(ds4, 0.3)]
with open('/mnt/bn/roboicllq-data1/aligned_robot_ds/calvin/datacomp.json', 'r') as f:
    tars = json.load(f)
ds0 = get_dataset(tars, seed=42)
ds_list += [(ds0, 0.68)]
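# wds.RandomMix samples from each source dataset with probability proportional
# to its weight; the weights here (0.1 + 0.02 + 0.3 + 0.68 = 1.1) are treated
# as relative sampling weights and need not sum to 1.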
ds = wds.RandomMix(*zip(*ds_list))
ds = wds.DataPipeline(ds)
loader = (
    wds.WebLoader(ds, num_workers=4, batch_size=args.batch_size, pin_memory=True).with_epoch(args.epoch)
)
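# .with_epoch(n) caps the otherwise unbounded webdataset stream at n batches,
# so args.epoch effectively sets how many optimizer steps one call of
# train_one_epoch() performs.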
# Earlier sweeps used fixed rates between 1e-6 and 1e-3, sometimes scaled as
# 4.5e-6 * batch_size * world_size.
base_lr = args.base_lr
# configure_optimizers is assumed to follow the taming-transformers VQModel
# interface, returning ([ae_optimizer, disc_optimizer], schedulers); custom
# schedulers are built below instead.
opt, _ = vqmodel.configure_optimizers(base_lr)
disc_start = 0  # the discriminator trains from the very first step
ae_opt, disc_opt = opt[0], opt[1]
ae_scheduler = CosineAnnealingWarmupRestarts(ae_opt,
                                             first_cycle_steps=args.epoch,
                                             cycle_mult=1,
                                             max_lr=base_lr,
                                             min_lr=base_lr * 0.9999999,
                                             warmup_steps=args.epoch / 10,
                                             gamma=1)
disc_scheduler = CosineAnnealingWarmupRestarts(disc_opt,
                                               first_cycle_steps=args.epoch,
                                               cycle_mult=1,
                                               max_lr=base_lr,
                                               min_lr=base_lr * 0.9999999,
                                               warmup_steps=args.epoch / 10,
                                               gamma=1)
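# min_lr is set to ~max_lr, so the cosine decay is effectively disabled: after
# the warmup phase the learning rate stays (almost) constant at base_lr.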
vqmodel, ae_opt, disc_opt, ae_scheduler, disc_scheduler, loader = accelerator.prepare(
    vqmodel, ae_opt, disc_opt, ae_scheduler, disc_scheduler, loader)
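# prepare() moves everything to the right device and, in multi-GPU runs, wraps
# the model in DistributedDataParallel, which is why the training loop below
# reaches through .module when it is present.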
print('global config end ---------------------------------------------------')
log_iter = 50
acc_steps = 1  # gradient-accumulation placeholder (currently unused; every batch steps the optimizers)
def train_one_epoch(args):
    STEP_CNT = 0
    DISC_STEP_CNT = 0
    vqmodel.train()
    # prepare() only wraps the model in DDP in distributed runs, so fall back
    # to the bare model when there is no .module attribute
    model = vqmodel.module if hasattr(vqmodel, 'module') else vqmodel
    stat_dict = {}
    for data in loader:
        batch = data.to(device)
        # optimizer_idx 0: autoencoder/VQ loss, optimizer_idx 1: discriminator
        # loss, following the taming-transformers training_step convention
        aeloss, log_dict_ae = model.training_step(batch, 0, device, STEP_CNT)
        if STEP_CNT >= disc_start:
            discloss, log_dict_disc = model.training_step(batch, 1, device, STEP_CNT)
        accelerator.backward(aeloss)
        ae_opt.step()
        ae_opt.zero_grad()
        if STEP_CNT >= disc_start:
            accelerator.backward(discloss)
            disc_opt.step()
            disc_opt.zero_grad()
        # accumulate per-step stats on rank 0 and flush the running means to
        # tensorboard every log_iter steps
        if get_rank() == 0:
            for k, v in log_dict_ae.items():
                stat_dict[k] = stat_dict.get(k, 0) + v.cpu().item()
            if STEP_CNT >= disc_start:
                for k, v in log_dict_disc.items():
                    stat_dict[k] = stat_dict.get(k, 0) + v.cpu().item()
            if (STEP_CNT + 1) % log_iter == 0:
                samples_seen = STEP_CNT * args.batch_size * get_world_size()
                for k, v in stat_dict.items():
                    WRITER.add_scalar(k, v / log_iter, samples_seen)
                    stat_dict[k] = 0
                WRITER.add_scalar('lr_ae', ae_opt.param_groups[0]['lr'], samples_seen)
                WRITER.add_scalar('lr_disc', disc_opt.param_groups[0]['lr'], samples_seen)
        STEP_CNT += 1
        ae_scheduler.step(STEP_CNT)
        if STEP_CNT >= disc_start:
            DISC_STEP_CNT += 1
            disc_scheduler.step(STEP_CNT)
        # periodic checkpoint: rank 0 saves the unwrapped state_dict
        if (STEP_CNT + 1) % (log_iter * 1000) == 0:
            if get_rank() == 0:
                unwrapped_model = accelerator.unwrap_model(vqmodel)
                accelerator.save(unwrapped_model.state_dict(),
                                 os.path.join(log_path, 'checkpoint_vq_epoch_' + str(STEP_CNT) + '.tar'))
            accelerator.wait_for_everyone()
    print('train done after {} steps'.format(STEP_CNT))
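# Typical launch command (a sketch; the script name and flags are assumptions
# that depend on the cluster setup):
#   accelerate launch --multi_gpu this_script.py --batch_size 16 --log_folder f16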
def main(args):
    train_one_epoch(args)

if __name__ == '__main__':
    main(args)
    # give tensorboard / dataloader workers a moment to flush before teardown
    time.sleep(60)
    if dist.is_initialized():
        dist.destroy_process_group()
    print('train done!')
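# A minimal sketch (not part of the training flow) of restoring a checkpoint
# saved above for inspection, assuming the same load_model helper and codebook
# size used at the top of this script:
#   vqmodel = load_model(os.path.join(log_path, 'checkpoint_vq_epoch_9999.tar'), 8192)
#   vqmodel.eval()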