"""Train the streaming motion generation model (MotionStreamer): LLaMA-style blocks with
Two-Forward training and QK-Norm, operating on the motion latents encoded by the Causal TAE
trained in the first stage."""

import os
import json
import random
import warnings

import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from accelerate import Accelerator

from models.llama_model import LLaMAHF, LLaMAHFConfig
import options.option_transformer as option_trans
import utils.utils_model as utils_model

warnings.filterwarnings('ignore')

os.environ["TOKENIZERS_PARALLELISM"] = "false"

args = option_trans.get_args_parser()
torch.manual_seed(args.seed)


def unwrap(m):
    """Return the raw model if it is wrapped (e.g. in DistributedDataParallel)."""
    return m.module if hasattr(m, 'module') else m


class WarmupCosineDecayScheduler:
    """Linear LR warmup for `warmup_iters` steps, then cosine annealing down to `min_lr`."""

    def __init__(self, optimizer, warmup_iters, total_iters, min_lr=0):
        self.optimizer = optimizer
        self.warmup_iters = warmup_iters
        self.total_iters = total_iters
        self.min_lr = min_lr
        self.warmup_scheduler = LambdaLR(optimizer, lr_lambda=self.warmup_lambda)
        self.cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_iters - warmup_iters, eta_min=min_lr)

    def warmup_lambda(self, current_iter):
        if current_iter < self.warmup_iters:
            return float(current_iter) / float(max(1, self.warmup_iters))
        return 1.0

    def step(self, current_iter):
        # Advance whichever sub-scheduler the current iteration belongs to.
        if current_iter < self.warmup_iters:
            self.warmup_scheduler.step()
        else:
            self.cosine_scheduler.step()

    def state_dict(self):
        return {'warmup_iters': self.warmup_iters, 'total_iters': self.total_iters, 'min_lr': self.min_lr}

    def load_state_dict(self, state_dict):
        self.warmup_iters = state_dict['warmup_iters']
        self.total_iters = state_dict['total_iters']
        self.min_lr = state_dict['min_lr']


args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
os.makedirs(args.out_dir, exist_ok=True)

# Distributed / mixed-precision context is managed by HuggingFace Accelerate.
accelerator = Accelerator()
comp_device = accelerator.device

# Logging: text logger plus TensorBoard writer in the experiment output directory.
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

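# Training data: captions paired with pre-computed motion latent sequences from the
# first-stage Causal TAE (loaded from `latent_dir`).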
from humanml3d_272 import dataset_TM_train_motionstreamer

train_loader = dataset_TM_train_motionstreamer.DATALoader(
    args.dataname, args.batch_size, unit_length=2**args.down_t, latent_dir=args.latent_dir
)

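# Text condition: a frozen Sentence-T5-XL encoder (fp16, eval mode) maps each caption to a
# single sentence embedding, which is handed to the transformer as its prompt via set_prompt().
# About 10% of captions are blanked in the training loop below, presumably to support
# classifier-free guidance at sampling time.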
from sentence_transformers import SentenceTransformer

t5_model = SentenceTransformer("sentence-t5-xl", device=comp_device)
t5_model.half()
t5_model.eval()
for p in t5_model.parameters():
    p.requires_grad = False

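# Autoregressive backbone: LLaMA-style blocks (with QK-Norm, per the module docstring) plus a
# diffusion head of `num_diffusion_head_layers` layers that maps each hidden state to a
# continuous motion latent of dimension `latent_dim`.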
config = LLaMAHFConfig.from_name('Normal_size')

trans_encoder = LLaMAHF(
    config=config,
    num_diffusion_head_layers=args.num_diffusion_head_layers,
    input_token_dim=args.latent_dim,
    device=comp_device,
)

if args.resume_trans is not None:
    print('loading transformer checkpoint from {}'.format(args.resume_trans))
    ckpt = torch.load(args.resume_trans, map_location='cpu')
    # Strip a leading 'module.' prefix left over from DDP-wrapped checkpoints.
    new_ckpt_trans = {}
    for key in ckpt['trans'].keys():
        new_key = '.'.join(key.split('.')[1:]) if key.split('.')[0] == 'module' else key
        new_ckpt_trans[new_key] = ckpt['trans'][key]
    trans_encoder.load_state_dict(new_ckpt_trans, strict=True)

trans_encoder.train()
trans_encoder.to(comp_device)

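# Optimizer and LR schedule: linear warmup over the first total_iter // 10 iterations, then
# cosine decay for the rest of training (e.g. with total_iter = 100_000 the warmup covers the
# first 10_000 iterations; these numbers are only illustrative).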
optimizer = utils_model.initial_optim(args.decay_option, args.lr, args.weight_decay, trans_encoder, args.optimizer)
scheduler = WarmupCosineDecayScheduler(optimizer, args.total_iter // 10, args.total_iter)

t5_model, trans_encoder, optimizer, train_loader = accelerator.prepare(
    t5_model, trans_encoder, optimizer, train_loader
)
# Keep a handle on the unwrapped model so custom methods (set_prompt, diff_loss) remain
# accessible behind any DDP wrapper added by accelerator.prepare().
base = accelerator.unwrap_model(trans_encoder)
# cycle() turns the finite DataLoader into an endless iterator for iteration-based training.
train_loader_iter = dataset_TM_train_motionstreamer.cycle(train_loader)

# Window size K of the Temporal-DiT head: only the last K positions of each sequence are
# passed to the diffusion head during training (see forward_loss_withmask_2_forward_streaming).
args.dit_window = 2


def lengths_to_mask(lengths, max_len):
    # Boolean mask that is True for valid positions and False for padding.
    return torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)

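# Example: lengths_to_mask(torch.tensor([2, 4]), 4) ->
#   [[True, True, False, False],
#    [True, True, True,  True ]]
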
def cosine_decay(step, total_steps, start_value=1.0, end_value=0.0):
    """Cosine schedule running from `start_value` at step 0 to `end_value` at `total_steps`."""
    step = torch.tensor(step, dtype=torch.float32)
    total_steps = torch.tensor(total_steps, dtype=torch.float32)
    cosine_factor = 0.5 * (1 + torch.cos(torch.pi * step / total_steps))
    return end_value + (start_value - end_value) * cosine_factor

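# e.g. cosine_decay(0, 1000) -> 1.0, cosine_decay(500, 1000) -> 0.5, cosine_decay(1000, 1000) -> 0.0
# (returned as 0-dim tensors).
#
# replace_with_pred below is the full-sequence form of scheduled sampling: a random subset of
# timesteps (fraction given by the cosine schedule) is swapped for the diffusion head's x0
# predictions.  This script's training loop uses the windowed variant implemented inline in
# forward_loss_withmask_2_forward_streaming instead.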
def replace_with_pred(latents, pred_xstart, step, total_steps):
    decay_factor = cosine_decay(step, total_steps).to(latents.device)
    b, l, d = latents.shape
    num_replace = int(l * decay_factor)
    replace_indices = torch.randperm(l, device=latents.device)[:num_replace]
    replace_mask = torch.zeros(b, l, dtype=torch.bool, device=latents.device)
    replace_mask[:, replace_indices] = True
    updated_latents = latents.clone()
    updated_latents[replace_mask] = pred_xstart[replace_mask]
    return updated_latents


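# Two-Forward training, as implemented below:
#   Forward 1: run the transformer teacher-forced on the ground-truth latents; the diffusion
#       head then predicts x0 for the last K positions (no gradients needed for this step).
#   Replacement: a cosine-annealed fraction of those K positions is overwritten with the head's
#       predictions (positions belonging to the conditioning segment A stay ground truth).
#   Forward 2: re-encode the partially self-generated sequence and compute the diffusion-head
#       loss on the K-position window against the original targets; only this loss is
#       backpropagated.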
def forward_loss_withmask_2_forward_streaming(latents, trans, m_lens, feat_text,
                                              step, total_steps, A_token_length, K=None):
    """
    Two-Forward training with a *windowed* Temporal-DiT:
      - the autoregressive transformer sees the full sequence;
      - the diffusion head sees only the last K positions (causal window).
    """
    K = K or getattr(args, "dit_window", 2)

    latents = latents.to(comp_device)
    feat_text = feat_text.to(comp_device)
    A_token_length = A_token_length.to(comp_device)

    B, L, D = latents.shape
    L_eff = L - 1  # number of next-latent prediction targets
    if L_eff <= 0:
        raise ValueError("Sequence too short for next-token training.")

    base.set_prompt(feat_text)

    # Forward 1: teacher-forced pass of the transformer over the ground-truth latents.
    conditions = trans(latents, feature=None)
    # Align hidden states with their targets: z_full[:, j] is the state used to predict
    # target_full[:, j] = latents[:, j + 1].
    z_full = conditions[:, 1:-1, :]
    target_full = latents[:, 1:, :]

    # Validity mask over the L_eff prediction positions (padding excluded) ...
    eff_lens = (m_lens - 1).clamp(min=0)
    full_mask = lengths_to_mask(eff_lens, L_eff)
    # ... and drop positions inside the conditioning segment A from the loss.
    for b in range(B):
        a_excl = max(0, A_token_length[b].item() - 1)
        if a_excl > 0:
            full_mask[b, :a_excl] = False

    # Restrict the diffusion head to the last W (= K) prediction positions.
    W = min(K, L_eff)
    tail_start = L_eff - W
    z = z_full[:, tail_start:, :]
    target = target_full[:, tail_start:, :]
    mask = full_mask[:, tail_start:]
    mask_flat = mask.reshape(B * W).float()

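    # Window illustration (hypothetical shapes): with L = 50 latents, L_eff = 49 targets and
    # K = 2 we get W = 2 and tail_start = 47, i.e. the diffusion head is trained only on the
    # hidden states that predict the last two latents (positions 48 and 49).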
    base.diff_loss.set_sequence_layout(B, W)

    # Diffusion-head predictions for the window (no gradients needed here).
    with torch.no_grad():
        loss0, pred_xstart_full = base.diff_loss(
            target=target.reshape(B * W, D),
            z=z.reshape(B * W, -1),
            mask=None
        )
        pred_xstart = pred_xstart_full.view(B, W, D)

    # Positions that fall inside the conditioning segment A keep their ground-truth latents
    # (they must never be overwritten by model predictions).
    for b in range(B):
        a_excl = max(0, A_token_length[b].item() - 1)
        cut = max(0, min(W, a_excl - tail_start))
        if cut > 0:
            pred_xstart[b, :cut, :] = target[b, :cut, :]

    # Fraction of window positions to replace with predictions, annealed 1 -> 0 over training.
    decay_ratio = 0.5 * (1.0 + torch.cos(
        torch.pi * torch.tensor(step, dtype=torch.float32, device=latents.device)
        / torch.tensor(total_steps, dtype=torch.float32, device=latents.device)
    )).item()
    k = int(W * decay_ratio)

    # Splice the selected predictions back into the input sequence.
    updated_latents = latents.clone()
    if k > 0:
        replace_idx = torch.randperm(W, device=latents.device)[:k]
        # Window index i corresponds to raw latent position 1 + tail_start + i
        # (z_full[:, j] predicts latents[:, j + 1]).
        raw_positions = 1 + tail_start + replace_idx
        updated_latents[:, raw_positions, :] = pred_xstart[:, replace_idx, :]

    # Forward 2 (with gradients): re-encode the partially self-generated sequence.
    updated_conditions = trans(updated_latents, feature=None)
    updated_z_full = updated_conditions[:, 1:-1, :]
    updated_z = updated_z_full[:, tail_start:, :]

    # Diffusion-head loss on the window, against the original targets, with the validity mask.
    updated_loss, _ = base.diff_loss(
        target=target.reshape(B * W, D),
        z=updated_z.reshape(B * W, -1),
        mask=mask_flat
    )
    return updated_loss


nb_iter, avg_loss_cls = 0, 0.0
args.print_iter = 100     # logging interval (iterations)
args.save_iter = 10000    # checkpointing interval (iterations)

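# Training loop (iteration-based): each step samples a batch, encodes the captions, computes
# the Two-Forward loss on the K-latent window, backpropagates, and advances the warmup/cosine
# LR schedule.  Losses are logged every `print_iter` iterations and a checkpoint is written
# every `save_iter` iterations.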
while nb_iter <= args.total_iter:
    batch = next(train_loader_iter)
    # caption: list of strings; m_tokens: [B, T, latent_dim] TAE latents; m_tokens_len: valid
    # lengths; A_token_length: length of the leading segment treated as given context
    # (excluded from the loss inside the Two-Forward function).
    caption, m_tokens, m_tokens_len, A_token_length = batch
    caption = list(caption)
    m_tokens, m_tokens_len = m_tokens.to(comp_device), m_tokens_len.to(comp_device)
    A_token_length = A_token_length.to(comp_device)

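    # Randomly blank ~10% of the captions so the model also learns the unconditional
    # distribution, presumably to enable classifier-free guidance at inference.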
    bs = len(caption)
    num_masked = int(bs * 0.1)
    if num_masked > 0:
        for idx in random.sample(range(bs), num_masked):
            caption[idx] = ''

    # Sentence embedding of each caption from the frozen T5 encoder.
    feat_text = torch.from_numpy(t5_model.encode(caption)).float().to(comp_device)

    input_latent = m_tokens[:, :-1, :]

    loss_cls = forward_loss_withmask_2_forward_streaming(
        latents=input_latent,
        trans=trans_encoder,
        m_lens=m_tokens_len,
        feat_text=feat_text,
        step=nb_iter,
        total_steps=args.total_iter,
        A_token_length=A_token_length,
        K=args.dit_window,
    )

    optimizer.zero_grad()
    accelerator.backward(loss_cls)   # handles gradient scaling / cross-process sync
    optimizer.step()
    scheduler.step(nb_iter)

    avg_loss_cls += loss_cls.item()
    nb_iter += 1

    # Periodic logging (loss averaged over the last `print_iter` iterations).
    if nb_iter % args.print_iter == 0:
        avg_loss_cls = avg_loss_cls / args.print_iter
        if accelerator.is_main_process:
            writer.add_scalar('./Loss/train', avg_loss_cls, nb_iter)
            writer.add_scalar('./LR/train', optimizer.param_groups[0]['lr'], nb_iter)
            logger.info(f"Train. Iter {nb_iter} : Loss. {avg_loss_cls:.5f}")
        avg_loss_cls = 0.0

    # Periodic checkpointing (main process writes; the other ranks wait at the barrier).
    if nb_iter % args.save_iter == 0:
        if accelerator.is_main_process:
            torch.save({'trans': unwrap(trans_encoder).state_dict()},
                       os.path.join(args.out_dir, 'latest.pth'))
        accelerator.wait_for_everyone()