| """Train original text to motion generation model with llama blocks, Two-Forward strategy and QK-Norm, using the motion latents encoded by the Causal TAE (trained in the first stage).""" |

import os
import json
import math
import random
import warnings

import torch
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from accelerate import Accelerator

from models.llama_model import LLaMAHF, LLaMAHFConfig
from humanml3d_272 import dataset_TM_train
import options.option_transformer as option_trans
import utils.utils_model as utils_model

warnings.filterwarnings('ignore')

os.environ["TOKENIZERS_PARALLELISM"] = "false"

args = option_trans.get_args_parser()
torch.manual_seed(args.seed)
random.seed(args.seed)  # also seed Python's RNG, used below for caption dropout

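
# Learning-rate schedule: linear warmup to the base LR over `warmup_iters`
# steps, then cosine annealing down to `min_lr` for the remaining iterations.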
class WarmupCosineDecayScheduler:
    def __init__(self, optimizer, warmup_iters, total_iters, min_lr=0):
        self.optimizer = optimizer
        self.warmup_iters = warmup_iters
        self.total_iters = total_iters
        self.min_lr = min_lr

        self.warmup_scheduler = LambdaLR(optimizer, lr_lambda=self.warmup_lambda)
        self.cosine_scheduler = CosineAnnealingLR(optimizer,
                                                  T_max=total_iters - warmup_iters,
                                                  eta_min=min_lr)

    def warmup_lambda(self, current_iter):
        if current_iter < self.warmup_iters:
            return float(current_iter) / float(max(1, self.warmup_iters))
        return 1.0

    def step(self, current_iter):
        if current_iter < self.warmup_iters:
            self.warmup_scheduler.step()
        else:
            self.cosine_scheduler.step()

    def state_dict(self):
        # Only the hyperparameters are saved; the wrapped schedulers keep
        # their own step counters, which are not restored on resume.
        return {
            'warmup_iters': self.warmup_iters,
            'total_iters': self.total_iters,
            'min_lr': self.min_lr,
        }

    def load_state_dict(self, state_dict):
        self.warmup_iters = state_dict['warmup_iters']
        self.total_iters = state_dict['total_iters']
        self.min_lr = state_dict['min_lr']


args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
os.makedirs(args.out_dir, exist_ok=True)

# Accelerate handles device placement and optional multi-GPU training.
accelerator = Accelerator()
comp_device = accelerator.device

logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

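
# Each batch is (caption, motion-latent sequence, latent length); unit_length
# presumably matches the temporal downsampling (2**down_t) of the Causal TAE.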
train_loader = dataset_TM_train.DATALoader(args.dataname, args.batch_size, args.latent_dir, unit_length=2**args.down_t)

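
# Frozen SentenceT5-XXL text encoder: produces one fixed-size embedding per
# caption, used as the conditioning signal for the transformer.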
from sentence_transformers import SentenceTransformer
t5_model = SentenceTransformer('sentencet5-xxl/')
t5_model.eval()
for p in t5_model.parameters():
    p.requires_grad = False

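
# LLaMA-style autoregressive backbone (with QK-Norm) plus a diffusion-MLP
# head; block_size caps the sequence length (text condition + motion latents).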
config = LLaMAHFConfig.from_name('Normal_size')
config.block_size = 78
trans_encoder = LLaMAHF(config, args.num_diffusion_head_layers, args.latent_dim, comp_device)

if args.resume_trans is not None:
    print('loading transformer checkpoint from {}'.format(args.resume_trans))
    ckpt = torch.load(args.resume_trans, map_location='cpu')
    # Strip the 'module.' prefix left by DDP-wrapped checkpoints.
    new_ckpt_trans = {}
    for key in ckpt['trans'].keys():
        if key.split('.')[0] == 'module':
            new_key = '.'.join(key.split('.')[1:])
        else:
            new_key = key
        new_ckpt_trans[new_key] = ckpt['trans'][key]
    trans_encoder.load_state_dict(new_ckpt_trans, strict=True)
trans_encoder.train()
trans_encoder.to(comp_device)


optimizer = utils_model.initial_optim(args.decay_option, args.lr, args.weight_decay, trans_encoder, args.optimizer)
# Warm up over the first 10% of iterations, then cosine-decay.
scheduler = WarmupCosineDecayScheduler(optimizer, args.total_iter // 10, args.total_iter)

t5_model, trans_encoder, optimizer, train_loader = accelerator.prepare(t5_model, trans_encoder, optimizer, train_loader)
train_loader_iter = dataset_TM_train.cycle(train_loader)

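
# The diffusion head sees each (condition, target) pair diffmlps_batch_mul
# times per step, i.e. several noise samples per token.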
diffmlps_batch_mul = 4


def lengths_to_mask(lengths, max_len):
    # Boolean mask: True for valid positions, False for padding.
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    return mask


def get_mask_subset_prob(mask, prob):
    # Keep each True entry of `mask` with probability `prob` (unused in this script).
    subset_mask = torch.bernoulli(mask, p=prob) & mask
    return subset_mask


def uniform(shape, device=None):
    # Uniform samples in [0, 1) (unused in this script).
    return torch.zeros(shape, device=device).float().uniform_(0, 1)


def cosine_schedule(t):
    # Cosine masking schedule (unused in this script).
    return torch.cos(t * math.pi * 0.5)

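
# Scheduled-sampling helpers for the Two-Forward strategy: cosine_decay sets
# the fraction of positions to replace, replace_with_pred performs the mix-in.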

def cosine_decay(step, total_steps, start_value=1.0, end_value=0.0):
    # Cosine interpolation from end_value (at step 0) to start_value (at
    # total_steps). With the defaults, the returned ratio rises from 0 to 1,
    # so more positions are replaced as training progresses.
    step = torch.tensor(step, dtype=torch.float32)
    total_steps = torch.tensor(total_steps, dtype=torch.float32)
    cosine_factor = 0.5 * (1 + torch.cos(torch.pi * step / total_steps))
    return start_value + (end_value - start_value) * cosine_factor


def replace_with_pred(latents, pred_xstart, step, total_steps):
    # Replace a schedule-dependent random subset of latent positions (shared
    # across the batch) with the model's first-pass predictions.
    decay_factor = cosine_decay(step, total_steps).to(latents.device)

    b, l, d = latents.shape
    num_replace = int(l * decay_factor)
    replace_indices = torch.randperm(l)[:num_replace]

    replace_mask = torch.zeros(b, l, dtype=torch.bool, device=latents.device)
    replace_mask[:, replace_indices] = True

    updated_latents = latents.clone()
    updated_latents[replace_mask] = pred_xstart[replace_mask]
    return updated_latents

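
# Two-Forward training step: pass 1 (no grad) runs teacher forcing to get the
# model's own predictions; pass 2 runs on the partially self-predicted input,
# and only its masked diffusion loss is backpropagated.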
def forward_loss_withmask_2_forward(latents, trans, m_lens, feat_text, step, total_steps):
    """latents: ground-truth motion latents; feat_text: text condition."""
    b, l, d = latents.shape
    mask = lengths_to_mask(m_lens, l)
    mask = mask.reshape(b * l).repeat(diffmlps_batch_mul)

    # First forward pass (no gradients needed): obtain the diffusion head's
    # predicted clean latents (pred_xstart) under teacher forcing.
    target = latents.clone().detach().reshape(b * l, -1)
    with torch.no_grad():
        conditions = trans(latents, feat_text).contiguous()
        z = conditions[:, :-1, :].reshape(b * l, -1)
        _, pred_xstart = trans.diff_loss(target=target, z=z)
    pred_xstart = pred_xstart.detach().reshape(b, l, -1)

    # Second forward pass on the partially self-predicted sequence.
    updated_latents = replace_with_pred(latents, pred_xstart, step, total_steps)
    updated_conditions = trans(updated_latents, feat_text).contiguous()
    updated_z = updated_conditions[:, :-1, :]

    # Replicate (target, z) pairs diffmlps_batch_mul times, then keep only the
    # valid (non-padded) positions.
    updated_target = latents.clone().detach().reshape(b * l, -1).repeat(diffmlps_batch_mul, 1)
    updated_z = updated_z.reshape(b * l, -1).repeat(diffmlps_batch_mul, 1)
    updated_target = updated_target[mask]
    updated_z = updated_z[mask]

    updated_loss, _ = trans.diff_loss(target=updated_target, z=updated_z)
    return updated_loss



nb_iter, avg_loss = 0, 0.
args.print_iter = 100   # log every 100 iterations
args.save_iter = 10000  # checkpoint every 10k iterations

while nb_iter <= args.total_iter:
    batch = next(train_loader_iter)
    text, m_tokens, m_tokens_len = batch
    text = list(text)
    m_tokens, m_tokens_len = m_tokens.to(comp_device), m_tokens_len.to(comp_device)
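
    # Conditioning dropout: blank out ~10% of captions so the model also learns
    # an unconditional distribution (presumably for classifier-free guidance at
    # sampling time).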
    bs = len(text)
    num_masked = int(bs * 0.1)
    mask_indices = random.sample(range(bs), num_masked)
    for idx in mask_indices:
        text[idx] = ''

    # Encode captions with the frozen T5 model (returns a NumPy array).
    feat_text = torch.from_numpy(t5_model.encode(text)).float()
    feat_text = feat_text.to(comp_device)
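
    # Use all but the last latent as the model input (next-latent prediction).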
    input_latent = m_tokens[:, :-1]

    if args.num_gpus > 1:
        # Reach through the DDP wrapper to call the custom loss helper.
        loss = forward_loss_withmask_2_forward(latents=input_latent, trans=trans_encoder.module, m_lens=m_tokens_len, feat_text=feat_text, step=nb_iter, total_steps=args.total_iter)
    else:
        loss = forward_loss_withmask_2_forward(latents=input_latent, trans=trans_encoder, m_lens=m_tokens_len, feat_text=feat_text, step=nb_iter, total_steps=args.total_iter)

    optimizer.zero_grad()
    accelerator.backward(loss)
    optimizer.step()
    scheduler.step(nb_iter)

    avg_loss = avg_loss + loss.item()
    nb_iter += 1

    if nb_iter % args.print_iter == 0:
        if accelerator.is_main_process:
            avg_loss = avg_loss / args.print_iter
            writer.add_scalar('./Loss/train', avg_loss, nb_iter)
            writer.add_scalar('./LR/train', optimizer.param_groups[0]['lr'], nb_iter)
            msg = f"Train. Iter {nb_iter} : Loss. {avg_loss:.5f}"
            logger.info(msg)
        avg_loss = 0.  # reset on every rank, not just the logging one

    if nb_iter % args.save_iter == 0:
        if accelerator.is_main_process:
            torch.save({
                'trans': trans_encoder.state_dict(),
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.out_dir, 'latest.pth'))

accelerator.wait_for_everyone()