# DeLVM / dist_train_vqgan.py
# Uploaded via huggingface_hub by jirong (commit ee3e701, verified)
import os
# os.environ['CUDA_LAUNCH_BLOCKING']='1'
import numpy as np
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import argparse
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
# import cv2
import torch
from torch.nn.utils import clip_grad_norm_
import yaml
from dist_train_utils import print, get_world_size, get_rank, get_local_rank, barrier, reduce_sum, reduce_mean
from tqdm import tqdm
from lr_utils import CosineAnnealingWarmupRestarts
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
import webdataset as wds
from dataset.get_vqgan_wds import get_dataset, handle_exception
# from dataset.vq_wds import get_dataset, my_sample_decoder, my_sample_prec
import json
import time
from accelerate import Accelerator
from accelerate.utils import set_seed
from data_generation.vqgan.load import load_model
import random
# ----- CLI arguments & distributed setup --------------------------------------
parser = argparse.ArgumentParser()
#ddp
# --local_rank is accepted for torch.distributed launcher compatibility;
# actual device placement is handled by accelerate below.
parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
parser.add_argument("--batch_size", type=int, default=16)
# NOTE(review): despite the name, --epoch is used below as a step/cycle count
# (with_epoch / first_cycle_steps), not a number of passes over the data — confirm.
parser.add_argument("--epoch", type=int, default=20)
# NOTE(review): --base_lr is parsed but never read; the LR actually used is the
# hard-coded `base_lr` defined further down.
parser.add_argument("--base_lr", type=float, default=4.5e-6)
parser.add_argument("--log_folder", type=str, default='f16')
args = parser.parse_args()
set_seed(42)  # seed all RNGs for reproducibility
accelerator = Accelerator()
# class MyLoader(wds.WebLoader, torch.utils.data.DataLoader):
#     pass
# train_dataloader = MyLoader(train_dataset, ...)
# with open('/mnt/bn/robotics-data-hl/jirong/git/incontextrobotics/models/embedding/VQGAN//model.yaml', encoding='utf-8') as f:
#     cfg = yaml.load(f, Loader=yaml.FullLoader)
log_folder = args.log_folder
log_path = os.path.join('/mnt/bn/roboicl-jirong/codebase/DeLVM/logs/', log_folder)
# Create the log directory on rank 0 only, then barrier so all ranks see it.
if get_rank() == 0:
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.system('chmod -R 777 ' + log_path)
accelerator.wait_for_everyone()
timestamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
# TensorBoard writer exists on rank 0 only; all logging below is rank-0 gated.
if get_rank() == 0:
    WRITER = SummaryWriter(log_path+'/'+timestamp, max_queue=1000)
print('world_size: {}, rank: {}, local_rank: {}'.format(get_world_size(), get_rank(), get_local_rank()), all=True)
# Manual process-group / device setup superseded by accelerate:
# dist.init_process_group(backend='nccl', rank = get_rank(), world_size=get_world_size())
# # assign gpu
# torch.cuda.set_device(get_local_rank())
# # torch.cuda.set_device(7)
# device = torch.cuda.current_device()
device = accelerator.device
print('device:', device, all=True)
# ----- Model: load VQGAN from checkpoint --------------------------------------
# Resume from a previous run's checkpoint; the second argument to load_model is
# the codebook size (8192 here). The commented paths below are earlier
# checkpoints / ablations kept for reference.
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_muse_finetune_calvin_vl_20m_1e-6_256/checkpoint_vq_epoch_48127.tar')
# vqmodel = load_model()
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_1024/checkpoint_vq_epoch_139999.tar')
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_1024/checkpoint_vq_epoch_64999.tar')
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_16384_160m_1e-4_192_disc_50000/checkpoint_vq_epoch_49999.tar')
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_16384_160m_1e-4_192_subset/checkpoint_vq_epoch_549999.tar')
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_16384_160m_3e-4_192/checkpoint_vq_epoch_149999.tar')
# vqmodel = load_model('/mnt/bn/roboicl-jirong/codebase/DeLVM/vqgan_ckpt/ckpt.pth', 8192)
vqmodel = load_model('/mnt/bn/roboicl-jirong/codebase/DeLVM/logs/f16_192_real_calvin_robot_datacomp_1e-5_disc_start_0_weight_0.2_acc_1/checkpoint_vq_epoch_9999.tar', 8192)
# vqmodel = load_model('/mnt/bn/roboicl-jirong/codebase/DeLVM/logs/f16_192_real_calvin_robot_datacomp_1e-5_disc_start_0_weight_0.2_acc_32/checkpoint_vq_epoch_19199.tar', 8192)
# vqmodel = load_model('/mnt/bn/roboicl-jirong/codebase/DeLVM/logs/f16_256_real_calvin_datacomp_1e-5_disc_start_0_weight_0.1/checkpoint_vq_epoch_34999.tar', 8192)
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_muse_finetune_teacher_2048_aug_disc_0.3/checkpoint_vq_epoch_89999.tar', 2048)
# vqmodel = load_model(None, 2048)
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_muse_finetune_teacher_2048_aug_disc_0.3_resume/checkpoint_vq_epoch_23999.tar', 2048)
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_muse_finetune_teacher_2048/checkpoint_vq_epoch_122499.tar', 2048)
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_muse_finetune_teacher_2048_resume_aug/checkpoint_vq_epoch_97499.tar', 2048)
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/git/DeLVM/vqgan_ckpt/ckpt.pth', 8192)
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_muse_ego4d_llava_calvin_3e-6_disc_0/checkpoint_vq_epoch_66999.tar', 8192)
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_muse_finetune_teacher_2048_aug_disc_0.3/checkpoint_vq_epoch_89999.tar', 2048)
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_muse_finetune_teacher_2048_aug_disc_0.5_resume/checkpoint_vq_epoch_999.tar', 2048)
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_192_real_data_muse_finetune_teacher_2048_resume_aug/checkpoint_vq_epoch_97499.tar', 2048)
# (disabled) copy matching-shape parameters from a teacher model into vqmodel:
# copied_param = []
# init_param = []
# for ((name, param), (name_s, param_s)) in zip(vq_model_t.named_parameters(), vqmodel.named_parameters()):
#     if param.shape == param_s.shape:
#         param_s.data = param.clone().data
#         copied_param.append(param_s)
#     else:
#         init_param.append(name_s)
#     # print (name, param.shape, name_s, param_s.shape)
# print ('params copied from teacher')
# print (init_param)
# vqmodel = get_vqmodel(cfg)
# vqmodel = load_model('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_8192_160m_1e-4_muse_pretrained/checkpoint_vq_epoch_49999.tar')
# ckpt = torch.load('/mnt/bn/robotics-data-hl/jirong/rlhf/vq/f16_8192_160m_1e-4_muse_pretrained/checkpoint_vq_epoch_49999.tar', map_location='cpu')
# vqmodel.load_state_dict(ckpt)
# Convert BatchNorm layers to SyncBatchNorm so running stats are synchronized
# across ranks during distributed training.
vqmodel = nn.SyncBatchNorm.convert_sync_batchnorm(vqmodel)
vqmodel = vqmodel.to(device)
# ----- Data: webdataset shards mixed with fixed sampling weights --------------
# with open('/mnt/bn/robotics-data-hl/jirong/git/incontextrobotics/dataset/vqgan_imgs/calvin_only.json', 'r') as f:
#     tars = json.load(f)
ds_list = []  # list of (dataset, sampling weight) pairs
# ds0 = get_dataset('/mnt/bn/roboicllq-data1/calvin_img/calvin_img_00000051.tar', seed=42)
# ds_list += [(ds0, 0.05)]
# print (len(tars), '============')
# ds1 = get_dataset('/mnt/bn/roboicllq-data1/processed_real/imgs/real_data_img_{0000..0590}.tar', seed=42)
# ds_list += [(ds1, 0.1)]
# ds2 = get_dataset('/mnt/bn/roboicllq-data1/processed_real/hand_imgs/real_data_hand_img_{0000..0590}.tar', seed=42)
# ds_list += [(ds2, 0.02)]
# CALVIN static-camera frames (weight 0.1) and gripper/hand-camera frames (0.02).
ds3 = get_dataset('/mnt/bn/roboicllq-data1/calvin_img/calvin_img_{00000000..00000110}.tar', seed=42)
ds_list += [(ds3, 0.1)]
ds6 = get_dataset('/mnt/bn/roboicllq-data1/calvin_img/hands/calvin_hands_img_{00000000..00000110}.tar', seed=42)
ds_list += [(ds6, 0.02)]
# Robot-image shard list loaded from JSON (weight 0.3).
with open('/mnt/bn/roboicl-jirong/codebase/RoboICL/robot_img.json', 'r') as f:
    tars = json.load(f)
# # # print (len(tars))
ds4 = get_dataset(tars, seed=42)
ds_list += [(ds4, 0.3)]
# DataComp shard list loaded from JSON (weight 0.68).
with open('/mnt/bn/roboicllq-data1/aligned_robot_ds/calvin/datacomp.json', 'r') as f:
    tars = json.load(f)
# # # print (len(tars))
ds0 = get_dataset(tars, seed=42)
ds_list += [(ds0, 0.68)]
# NOTE(review): the weights sum to 1.10, not 1.0 — presumably RandomMix treats
# them as relative frequencies; confirm against the webdataset version in use.
ds = wds.RandomMix(*zip(*ds_list))
ds = wds.DataPipeline(ds)
# NOTE(review): with_epoch(args.epoch) caps the loader at args.epoch batches
# per pass — confirm the intended unit (batches, not dataset epochs).
loader = (
    wds.WebLoader(ds, num_workers=4, batch_size=args.batch_size, pin_memory=True).with_epoch(args.epoch)
)
# loader = MyLoader(dataset=trainset, num_workers=4, batch_size=args.batch_size, pin_memory=True).with_epoch(args.epoch)
# ----- Optimizers & LR schedules ----------------------------------------------
# base_lr = 4.5e-6 * args.batch_size * get_world_size()
base_lr = 1e-5  # NOTE(review): hard-coded; overrides the --base_lr CLI flag
# earlier LR choices kept for reference:
# base_lr = 3e-6
# base_lr = 1e-6
# base_lr = 1e-3
# base_lr = 5e-4
# base_lr = 1e-4
# base_lr = 5e-4
# base_lr = 1e-4
# base_lr = 5e-5
# configure_optimizers returns an indexable pair of optimizers (AE, disc);
# the second return value is discarded in favor of the schedulers built below.
opt, _ = vqmodel.configure_optimizers(base_lr)
# ae_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt[0], T_max=args.epoch, eta_min=base_lr * 0.001)
# disc_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt[1], T_max=args.epoch, eta_min=base_lr * 0.001)
# disc_start = cfg['model']['params']['lossconfig']['params']['disc_start']
disc_start = 0  # discriminator loss is enabled from the very first step
ae_opt, disc_opt = opt[0], opt[1]
# NOTE(review): min_lr = base_lr * 0.9999999 makes the cosine decay effectively
# a constant LR after warmup — confirm this is intended.
ae_scheduler = CosineAnnealingWarmupRestarts(ae_opt,
                                             first_cycle_steps=args.epoch,
                                             cycle_mult=1,
                                             max_lr=base_lr,
                                             min_lr=base_lr*0.9999999,
                                             warmup_steps=args.epoch/10,
                                             gamma=1)
disc_scheduler = CosineAnnealingWarmupRestarts(disc_opt,
                                               first_cycle_steps=args.epoch,
                                               cycle_mult=1,
                                               max_lr=base_lr,
                                               min_lr=base_lr*0.9999999,
                                               warmup_steps=args.epoch/10,
                                               gamma=1)
# Wrap model, optimizers, schedulers and loader for distributed execution.
vqmodel, ae_opt, disc_opt, ae_scheduler, disc_scheduler, loader = accelerator.prepare(vqmodel, ae_opt, disc_opt, ae_scheduler, disc_scheduler, loader)
print ('global config end ---------------------------------------------------')
log_iter = 50   # flush averaged stats to TensorBoard every 50 steps
acc_steps = 1   # NOTE(review): declared but unused (no gradient accumulation here)
def train_one_epoch(args):
    """Run the VQGAN training loop over the mixed webdataset stream.

    Each iteration performs an autoencoder update and — once
    ``STEP_CNT >= disc_start`` — a discriminator update on the same batch.
    Rank 0 accumulates the scalars returned by ``training_step`` and writes
    their ``log_iter``-step averages to TensorBoard; a checkpoint of the
    unwrapped model is saved every ``log_iter * 1000`` steps.

    Relies on module-level globals: ``loader``, ``vqmodel``, ``device``,
    ``accelerator``, ``ae_opt``, ``disc_opt``, ``ae_scheduler``,
    ``disc_scheduler``, ``disc_start``, ``log_iter``, ``log_path``,
    ``WRITER`` (rank 0 only).

    Args:
        args: parsed CLI namespace; only ``args.batch_size`` is read here,
            to scale the TensorBoard x-axis to samples seen across ranks.
    """
    STEP_CNT = 0  # number of completed optimizer steps
    vqmodel.train()
    stat_dict = {}  # running sums of logged scalars; flushed every log_iter steps
    for i, data in enumerate(loader):
        batch = data.to(device)
        # optimizer_idx 0 -> autoencoder loss, 1 -> discriminator loss.
        # vqmodel is DDP-wrapped by accelerator.prepare, hence .module.
        aeloss, log_dict_ae = vqmodel.module.training_step(batch, 0, device, STEP_CNT)
        if STEP_CNT >= disc_start:
            discloss, log_dict_disc = vqmodel.module.training_step(batch, 1, device, STEP_CNT)
        accelerator.backward(aeloss)
        ae_opt.step()
        ae_opt.zero_grad()
        if STEP_CNT >= disc_start:
            accelerator.backward(discloss)
            disc_opt.step()
            disc_opt.zero_grad()
        if get_rank() == 0:
            # Accumulate every step so the values flushed below are true
            # log_iter-step averages. (Previously the scalars were only
            # sampled on the flush step yet still divided by log_iter,
            # under-reporting every logged value by ~log_iter x.)
            for k, v in log_dict_ae.items():
                stat_dict[k] = stat_dict.get(k, 0) + v.cpu().item()
            if STEP_CNT >= disc_start:
                for k, v in log_dict_disc.items():
                    stat_dict[k] = stat_dict.get(k, 0) + v.cpu().item()
            if (STEP_CNT + 1) % log_iter == 0:
                samples_seen = STEP_CNT * args.batch_size * get_world_size()
                for k, v in stat_dict.items():
                    WRITER.add_scalar(k, v / log_iter, samples_seen)
                    stat_dict[k] = 0
                WRITER.add_scalar('lr_ae', ae_opt.param_groups[0]['lr'], samples_seen)
                WRITER.add_scalar('lr_disc', disc_opt.param_groups[0]['lr'], samples_seen)
        STEP_CNT += 1
        # Schedulers are stepped with an explicit step index: this is a
        # per-step (not per-epoch) schedule.
        ae_scheduler.step(STEP_CNT)
        if STEP_CNT >= disc_start:
            disc_scheduler.step(STEP_CNT)
        if (STEP_CNT + 1) % (log_iter*1000) == 0:
            # Rank 0 saves the unwrapped (non-DDP) state dict; all ranks
            # then synchronize before continuing.
            if get_rank() == 0:
                unwrapped_model = accelerator.unwrap_model(vqmodel)
                accelerator.save(unwrapped_model.state_dict(), os.path.join(log_path, 'checkpoint_vq_epoch_' +str(STEP_CNT)+'.tar'))
            accelerator.wait_for_everyone()
    print ('epoch {} train done'.format(STEP_CNT))
def main(args):
    """Script entry point: run the single long training pass."""
    return train_one_epoch(args)
if __name__ == '__main__':
    # `time` is already imported at the top of the file; the redundant local
    # import has been removed.
    main(args)
    # Give straggling ranks / async log writers a moment to finish before
    # tearing the process group down.
    time.sleep(60)
    # Only destroy the process group if one was actually initialized —
    # destroy_process_group() raises in a single-process (no-NCCL) run.
    if dist.is_initialized():
        dist.destroy_process_group()
    print ('train done!')