|
|
| import math |
| from dataclasses import dataclass |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
| from typing_extensions import Self |
| from typing import Optional |
| from transformers.modeling_utils import PreTrainedModel |
| from torch.distributions import Categorical |
| import torch.nn.functional as F |
|
|
|
|
| @dataclass |
| class LLaMAHFConfig: |
| block_size: int = 78 |
| n_layer: int = 32 |
| n_head: int = 32 |
| n_embd: int = 4096 |
| T5_xxl_dim: int = 768 |
|
|
| @classmethod |
| def from_name(cls, name: str) -> Self: |
| return cls(**llama_configs[name]) |
|
|
|
|
# Named architecture presets consumed by LLaMAHFConfig.from_name.
llama_configs = {
    "Normal_size": dict(n_layer=12, n_head=12, n_embd=768)
}
|
|
|
|
| class LLaMAHF(nn.Module): |
    def __init__(self, config: LLaMAHFConfig, num_diffusion_head_layers=9, input_token_dim=16, device=torch.device('cuda'), width=1792) -> None:
        """Autoregressive motion transformer with a diffusion sampling head.

        Args:
            config: architecture hyper-parameters (block size, depth, width,
                T5 text-feature dim).
            num_diffusion_head_layers: depth of the DiffLoss head.
            input_token_dim: channel width of one continuous motion token.
            device: where the diffusion head is moved. NOTE(review): default
                is CUDA — `.to(device)` below fails on CPU-only hosts unless
                a device is passed explicitly.
            width: hidden width of the diffusion head.
        """
        super().__init__()
        assert config.block_size is not None
        self.config = config

        # Text conditioning features come from T5-XXL.
        cond_dim = config.T5_xxl_dim

        # Block and RMSNorm are presumably defined elsewhere in this module
        # (outside this view).
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(input_token_dim, config.n_embd),   # continuous motion-token embedding
                cond_embed=nn.Linear(cond_dim, config.n_embd),   # text-feature embedding
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=RMSNorm(config.n_embd),
            )
        )

        # Diffusion head: samples the next continuous token conditioned on the
        # trunk's last hidden state.
        target_channels = input_token_dim
        from models.diffloss import DiffLoss
        self.diff_loss = DiffLoss(
            target_channels=target_channels,
            z_channels=config.n_embd,
            width=width,
            depth=num_diffusion_head_layers,
            num_sampling_steps='50',
            grad_checkpointing=False,
        )
        self.diff_loss = self.diff_loss.to(device)
        # Final projection applied by forward()/forward_inference().
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)
        self.use_out_proj = True
| |
|
|
| def _tie_or_clone_weights(self, output_embeddings, input_embeddings): |
| """Tie or clone module weights depending of whether we are using TorchScript or not""" |
| output_embeddings.weight = input_embeddings.weight |
|
|
| if getattr(output_embeddings, "bias", None) is not None: |
| output_embeddings.bias.data = nn.functional.pad( |
| output_embeddings.bias.data, |
| ( |
| 0, |
| output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0], |
| ), |
| "constant", |
| 0, |
| ) |
| if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): |
| output_embeddings.out_features = input_embeddings.num_embeddings |
|
|
    def get_input_embeddings(self):
        """Return the motion-token embedding layer (HF-style accessor)."""
        return self.transformer.wte
| |
    def set_input_embeddings(self, value):
        """Replace the motion-token embedding layer (HF-style accessor)."""
        self.transformer.wte = value
|
|
    def get_output_embeddings(self):
        """Return the output head.

        NOTE(review): `self.lm_head` is not created in the visible __init__;
        it is presumably attached elsewhere (or via set_output_embeddings) —
        confirm before relying on this accessor.
        """
        return self.lm_head
| |
    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the output head (`lm_head`)."""
        self.lm_head = new_embeddings
|
|
| def _init_weights(self, module: nn.Module) -> None: |
| if isinstance(module, nn.Linear): |
| torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)) |
| elif isinstance(module, nn.Embedding): |
| torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)) |
| |
| |
|
|
    def forward_sample(self, idx: torch.Tensor, clip_feature: torch.Tensor, y_mask) -> torch.Tensor:
        """Trunk forward used during sampling with an explicit text mask.

        Args:
            idx: motion tokens generated so far ([] on the very first step).
            clip_feature: batched text features; only the first
                int(y_mask[0].sum()) positions are kept.
            y_mask: per-position text validity mask, also passed to each block.

        NOTE(review): uses `self.llama_proj`, which is not created in the
        visible __init__ (other paths use `self.transformer.cond_embed`) —
        confirm this attribute is attached elsewhere before calling.
        """
        text_length = clip_feature.shape[1]  # unused below; kept as-is
        if len(idx) == 0:
            x = self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :]
        else:
            _, t = idx.size()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # Prepend the (masked) projected text to the embedded motion tokens.
            x = self.transformer.wte(idx)
            x = torch.cat((self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :],x), dim=1)

        for block in self.transformer.h:
            x = block(x, y_mask)
        x = self.transformer.ln_f(x)
        # No output projection on this path; raw normed hidden states returned.
        logits = x
        return logits
|
|
| |
|
|
| def sample_for_eval_CFG(self, text, length=196, tokenize_model=None, device=torch.device('cuda'), unit_length=4, cfg=4.0): |
| max_token_len = length // unit_length |
| for k in range(max_token_len): |
| if k == 0: |
| x = [] |
| else: |
| x = xs |
| |
| feat_text = torch.from_numpy(tokenize_model.encode(text)).float() |
| feat_text = feat_text.to(device) |
| conditions = self.forward(x, feat_text) |
| conditions = conditions[:, -1, :] |
| |
| empty_text = '' |
| empty_feat_text = torch.from_numpy(tokenize_model.encode(empty_text)).float() |
| empty_feat_text = empty_feat_text.unsqueeze(0) |
| empty_feat_text = empty_feat_text.to(device) |
| empty_conditions = self.forward(x, empty_feat_text) |
| empty_conditions = empty_conditions[:, -1, :] |
| temperature = 1.0 |
| |
| |
| if cfg != 1: |
| mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
| sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
| scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
| else: |
| scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=1) |
|
|
| scaled_logits = scaled_logits.unsqueeze(0) |
| |
| if k == 0: |
| xs = scaled_logits |
| else: |
| xs = torch.cat((xs, scaled_logits), dim=1) |
|
|
| return xs |
| |
| |
| |
| |
    def sample_for_eval_CFG_inference(self, text, length=312, tokenizer=None, device=torch.device('cuda'), unit_length=4, reference_end_latent=None, threshold=0.1, cfg=4.0, temperature=1.0):
        """Autoregressive CFG sampling from a raw text prompt.

        Args:
            text: prompt string; `tokenizer` must expose `.encode(str) -> np.ndarray`.
            length / unit_length: motion length in frames / frames per token.
            reference_end_latent: optional latent; generation stops early when
                the newly sampled latent is within `threshold` (L2) of it.

        Returns:
            Accumulated latents of shape (1, T, C).
        """
        max_token_len = length // unit_length
        # Conditional prompt encoded once (loop-invariant).
        feat_text = torch.from_numpy(tokenizer.encode(text)).float()
        feat_text = feat_text.to(device)

        # Unconditional (empty) prompt encoded once.
        empty_text = ''
        empty_feat_text = torch.from_numpy(tokenizer.encode(empty_text)).float()
        empty_feat_text = empty_feat_text.unsqueeze(0)
        empty_feat_text = empty_feat_text.to(device)

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            # NOTE(review): the conditional pass uses forward_inference while
            # the unconditional pass uses forward — the two differ in how the
            # text embedding is unsqueezed; confirm this asymmetry is intended.
            conditions = self.forward_inference(x, feat_text)
            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_text)
            empty_conditions = empty_conditions[:, -1, :]

            # CFG: conditional + unconditional halves sampled jointly.
            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                # NOTE(review): with cfg == 1 the concatenated 2x batch is kept
                # whole here — confirm callers never pass cfg == 1 on this path.
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            # Optional early stop near a known end-of-motion latent.
            # The break happens BEFORE appending, so the end latent is dropped.
            if reference_end_latent is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_latent)**2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs
|
|
| |
| def sample_for_eval_CFG_inference2(self, feat_clip_text, empty_feat_clip_text, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0): |
| |
| import clip |
| max_token_len = length // unit_length |
| |
| for k in range(max_token_len): |
| if k == 0: |
| x = [] |
| else: |
| x = xs |
|
|
| try: |
| conditions = self.forward(x, feat_clip_text) |
| except: |
| conditions = self.forward(x, feat_clip_text.unsqueeze(0)) |
| |
| |
| conditions = conditions[:, -1, :] |
| |
| |
|
|
| empty_conditions = self.forward(x, empty_feat_clip_text) |
| empty_conditions = empty_conditions[:, -1, :] |
| |
| mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
| sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
| |
| if cfg != 1: |
| scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
| else: |
| scaled_logits = sampled_token_latent |
|
|
| scaled_logits = scaled_logits.unsqueeze(0) |
| |
| if reference_end_token is not None: |
| distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2)) |
| print(distance_l2) |
| if distance_l2 < threshold: |
| break |
|
|
| if k == 0: |
| xs = scaled_logits |
| else: |
| xs = torch.cat((xs, scaled_logits), dim=1) |
|
|
| return xs |
|
|
| def sample_for_eval_CFG_inference_next_one(self, current_token=[], feat_clip_text=None, empty_feat_clip_text=None, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0): |
| |
| import clip |
| max_token_len = length // unit_length |
| |
| |
| for k in range(1): |
| |
| if current_token == []: |
| x = [] |
| else: |
| x = torch.cat(current_token, dim=1) |
| |
| |
| try: |
| conditions = self.forward(x, feat_clip_text) |
| except: |
| conditions = self.forward(x, feat_clip_text.unsqueeze(0)) |
| |
| |
| conditions = conditions[:, -1, :] |
| |
|
|
| empty_conditions = self.forward(x, empty_feat_clip_text) |
| empty_conditions = empty_conditions[:, -1, :] |
|
|
| mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
| sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
| |
| if cfg != 1: |
| scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
| else: |
| scaled_logits = sampled_token_latent |
|
|
|
|
| scaled_logits = scaled_logits.unsqueeze(0) |
| |
|
|
| if k == 0: |
| xs = scaled_logits |
| else: |
| xs = torch.cat((xs, scaled_logits), dim=1) |
|
|
| return xs |
| |
|
|
| def sample_for_eval_CFG_babel(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3): |
|
|
| import clip |
| B_token_length = length // unit_length - A_motion.shape[0] |
|
|
| if tokenizer == 'clip': |
| A_text = clip.tokenize(A_text, truncate=True).to(device) |
| A_feat_clip_text = clip_model.encode_text(A_text).float() |
| B_text = clip.tokenize(B_text, truncate=True).to(device) |
| B_feat_clip_text = clip_model.encode_text(B_text).float() |
| elif tokenizer == 't5-xxl': |
| A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float() |
| A_feat_clip_text = A_feat_clip_text.to(device) |
| B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float() |
| B_feat_clip_text = B_feat_clip_text.to(device) |
| |
| A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0) |
| B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) |
|
|
| A_motion = A_motion.unsqueeze(0) |
| A_motion_embeddings = self.transformer.wte(A_motion) |
| B_motion = torch.tensor([]).to(device) |
|
|
| for k in range(B_token_length): |
| if k == 0: |
| x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1) |
| else: |
| x = xs |
|
|
| |
| conditions = self.forward_babel_eval(x) |
| conditions = conditions[:, -1, :] |
|
|
| empty_clip_text = '' |
| if tokenizer == 'clip': |
| empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) |
| empty_feat_clip_text = clip_model.encode_text(empty_text).float() |
| elif tokenizer == 't5-xxl': |
| empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() |
| empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) |
| empty_feat_clip_text = empty_feat_clip_text.to(device) |
|
|
| empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) |
| |
| if k == 0: |
| empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1) |
| empty_conditions = self.forward_babel_eval(empty_input) |
| else: |
| B_motion_embeddings = self.transformer.wte(B_motion) |
| empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1) |
| empty_conditions = self.forward_babel_eval(empty_input) |
| |
| empty_conditions = empty_conditions[:, -1, :] |
| temperature = 1.0 |
| |
| mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
| sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
| |
| if cfg != 1: |
| scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
| else: |
| scaled_logits = sampled_token_latent |
|
|
|
|
| scaled_logits = scaled_logits.unsqueeze(0) |
| |
|
|
| B_motion = torch.cat((B_motion, scaled_logits), dim=1) |
|
|
| scaled_logits_embedding = self.transformer.wte(scaled_logits) |
| xs = torch.cat((x, scaled_logits_embedding), dim=1) |
| |
|
|
| return xs, B_motion |
|
|
    def sample_for_eval_CFG_babel_inference(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):
        """BABEL A->B inference: generate segment B after prompt motion A, with CFG.

        Differs from sample_for_eval_CFG_babel by an extra unsqueeze on the
        text embeddings, return_attention=False on the conditional pass, and
        an optional early stop against `reference_end_token`.

        Returns:
            (xs, B_motion): final embedding sequence and generated B latents.
        """
        import clip
        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 'clip':
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float()
            A_feat_clip_text = A_feat_clip_text.to(device)
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        # NOTE(review): the text embeddings receive TWO leading singleton dims
        # here (unsqueeze twice), unlike sample_for_eval_CFG_babel — confirm
        # this matches what forward_babel_eval expects.
        A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0)
        A_text_embeddings = A_text_embeddings.unsqueeze(0)
        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
        B_text_embeddings = B_text_embeddings.unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        attention_weights = []  # kept as-is; never populated in this view

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)

            else:
                x = xs

            conditions = self.forward_babel_eval(x, return_attention=False)
            conditions = conditions[:, -1, :]

            # Unconditional branch re-encodes the empty prompt every step
            # (loop-invariant; left as-is to keep behavior identical).
            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            # CFG sampling in the diffusion head.
            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            # Early stop: break BEFORE appending, so the end latent is dropped.
            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion
| |
|
|
| def sample_for_eval_CFG_babel_inference_new(self, B_text, A_motion, if_categorial=False, length=78, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3): |
|
|
| import clip |
| B_token_length = length // unit_length |
|
|
| if tokenizer == 'clip': |
| A_text = clip.tokenize(A_text, truncate=True).to(device) |
| A_feat_clip_text = clip_model.encode_text(A_text).float() |
| B_text = clip.tokenize(B_text, truncate=True).to(device) |
| B_feat_clip_text = clip_model.encode_text(B_text).float() |
| elif tokenizer == 't5-xxl': |
| B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float() |
| B_feat_clip_text = B_feat_clip_text.to(device) |
|
|
| empty_clip_text = '' |
| if tokenizer == 'clip': |
| empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) |
| empty_feat_clip_text = clip_model.encode_text(empty_text).float() |
| elif tokenizer == 't5-xxl': |
| empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() |
| empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) |
| empty_feat_clip_text = empty_feat_clip_text.to(device) |
| |
| B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) |
|
|
| A_motion = A_motion.unsqueeze(0) |
| A_motion_embeddings = self.transformer.wte(A_motion) |
| B_motion = torch.tensor([]).to(device) |
|
|
| |
| attention_weights = [] |
|
|
| for k in range(B_token_length): |
| if k == 0: |
| x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1) |
| else: |
| x = xs |
|
|
| conditions = self.forward_babel_eval(x, return_attention=False) |
| conditions = conditions[:, -1, :] |
| |
|
|
| empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) |
| |
| if k == 0: |
| empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1) |
| |
| empty_conditions = self.forward_babel_eval(empty_input) |
| else: |
| B_motion_embeddings = self.transformer.wte(B_motion) |
| empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1) |
| empty_conditions = self.forward_babel_eval(empty_input) |
|
|
| empty_conditions = empty_conditions[:, -1, :] |
| temperature = 1.0 |
| |
| mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
| sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
| |
| if cfg != 1: |
| scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
| else: |
| scaled_logits = sampled_token_latent |
|
|
| scaled_logits = scaled_logits.unsqueeze(0) |
| |
| if reference_end_token is not None: |
| distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2)) |
| print(distance_l2) |
| if distance_l2 < threshold: |
| break |
| |
| B_motion = torch.cat((B_motion, scaled_logits), dim=1) |
|
|
| scaled_logits_embedding = self.transformer.wte(scaled_logits) |
| xs = torch.cat((x, scaled_logits_embedding), dim=1) |
| |
| |
| |
| return xs, B_motion |
|
|
|
|
| def sample_for_eval_CFG_babel_inference_new_demo(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0): |
|
|
| import clip |
| B_token_length = length // unit_length - A_motion.shape[0] |
| |
| if tokenizer == 'clip': |
| A_text = clip.tokenize(A_text, truncate=True).to(device) |
| A_feat_clip_text = clip_model.encode_text(A_text).float() |
| B_text = clip.tokenize(B_text, truncate=True).to(device) |
| B_feat_clip_text = clip_model.encode_text(B_text).float() |
| elif tokenizer == 't5-xxl': |
| B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float() |
| B_feat_clip_text = B_feat_clip_text.to(device) |
|
|
| empty_clip_text = '' |
| if tokenizer == 'clip': |
| empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device) |
| empty_feat_clip_text = clip_model.encode_text(empty_text).float() |
| elif tokenizer == 't5-xxl': |
| empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float() |
| empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0) |
| empty_feat_clip_text = empty_feat_clip_text.to(device) |
| |
| B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0) |
| B_text_embeddings = B_text_embeddings.unsqueeze(0) |
|
|
| A_motion = A_motion.unsqueeze(0) |
| A_motion_embeddings = self.transformer.wte(A_motion) |
| B_motion = torch.tensor([]).to(device) |
|
|
| |
| attention_weights = [] |
|
|
| for k in range(B_token_length): |
| if k == 0: |
| x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1) |
| |
| else: |
| x = xs |
|
|
| |
| conditions = self.forward_babel_eval(x, return_attention=False) |
| conditions = conditions[:, -1, :] |
| |
|
|
| empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) |
| |
| if k == 0: |
| empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1) |
| empty_conditions = self.forward_babel_eval(empty_input) |
| else: |
| B_motion_embeddings = self.transformer.wte(B_motion) |
| empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1) |
| empty_conditions = self.forward_babel_eval(empty_input) |
|
|
| empty_conditions = empty_conditions[:, -1, :] |
| |
| mix_conditions = torch.cat([conditions, empty_conditions], dim=0) |
| sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg) |
|
|
| |
| if cfg != 1: |
| scaled_logits, _ = sampled_token_latent.chunk(2, dim=0) |
| else: |
| scaled_logits = sampled_token_latent |
|
|
| scaled_logits = scaled_logits.unsqueeze(0) |
| |
| if reference_end_token is not None: |
| distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2)) |
| print(distance_l2) |
| if distance_l2 < threshold and k > 10: |
| break |
| |
| B_motion = torch.cat((B_motion, scaled_logits), dim=1) |
|
|
| scaled_logits_embedding = self.transformer.wte(scaled_logits) |
| xs = torch.cat((x, scaled_logits_embedding), dim=1) |
| |
| |
| |
| return xs, B_motion |
|
|
    def sample_for_eval_CFG_babel_inference_two_forward(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0):
        """
        Inference loop that mimics the "Two-Forward" training strategy.

        Each step samples a rough next token (pass 1), appends it to the
        motion sequence, re-runs the trunk (pass 2), and keeps the second
        sample as the final token. Only tokenizer='t5-xxl' is supported;
        anything else raises NotImplementedError.

        Returns (None, B_motion) where B_motion is (1, T_B, C) — only the
        generated tokens, with the A prompt stripped.

        NOTE(review): pass 2 always calls .chunk(2, dim=0) on the sample, which
        is only correct for cfg != 1 batches — confirm cfg == 1 is never used.
        """
        import clip
        print("\n--- [DEBUG] Entering Two-Forward Inference ---")

        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float().to(device)
        else:
            raise NotImplementedError("Only t5-xxl is supported for this function.")
        empty_feat_clip_text = torch.from_numpy(clip_model.encode('')).float().unsqueeze(0).to(device)

        # Text embeddings are loop-invariant; computed once.
        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0).unsqueeze(0)
        empty_text_embeddings = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)
        print(f"[DEBUG] Initial B_text_embeddings shape: {B_text_embeddings.shape}")
        print(f"[DEBUG] Initial empty_text_embeddings shape: {empty_text_embeddings.shape}")

        A_motion_embeddings = self.transformer.wte(A_motion.unsqueeze(0))
        current_motion_tokens = A_motion
        print(f"[DEBUG] Initial A_motion_embeddings shape: {A_motion_embeddings.shape}")
        print(f"[DEBUG] Initial current_motion_tokens shape: {current_motion_tokens.shape}")

        for k in range(B_token_length):
            print(f"\n--- [DEBUG] Generating Token {k+1}/{B_token_length} ---")

            if k == 0:
                step_motion_embeddings = A_motion_embeddings
            else:
                step_motion_embeddings = self.transformer.wte(current_motion_tokens.unsqueeze(0))
            print(f"[DEBUG] step_motion_embeddings shape: {step_motion_embeddings.shape}")

            # ---- Pass 1: rough prediction of the next token ----
            x_first_pass = torch.cat([B_text_embeddings, step_motion_embeddings], dim=1)
            empty_input_first_pass = torch.cat([empty_text_embeddings, step_motion_embeddings], dim=1)
            print(f"[DEBUG] x_first_pass shape: {x_first_pass.shape}")
            print(f"[DEBUG] empty_input_first_pass shape: {empty_input_first_pass.shape}")

            conditions_first_pass = self.forward_babel_eval(x_first_pass, return_attention=False)[:, -1, :]
            empty_conditions_first_pass = self.forward_babel_eval(empty_input_first_pass)[:, -1, :]

            mix_conditions_first_pass = torch.cat([conditions_first_pass, empty_conditions_first_pass], dim=0)
            pred_xstart_rough = self.diff_loss.sample(mix_conditions_first_pass, temperature=temperature, cfg=cfg)
            if cfg != 1:
                pred_xstart_rough, _ = pred_xstart_rough.chunk(2, dim=0)
            print(f"[DEBUG] pred_xstart_rough shape: {pred_xstart_rough.shape}")

            # ---- Pass 2: refine with the rough token appended ----
            updated_motion_tokens = torch.cat([current_motion_tokens, pred_xstart_rough], dim=0)
            updated_motion_embeddings = self.transformer.wte(updated_motion_tokens.unsqueeze(0))
            print(f"[DEBUG] updated_motion_embeddings shape: {updated_motion_embeddings.shape}")

            x_second_pass = torch.cat([B_text_embeddings, updated_motion_embeddings], dim=1)
            empty_input_second_pass = torch.cat([empty_text_embeddings, updated_motion_embeddings], dim=1)

            conditions_second_pass = self.forward_babel_eval(x_second_pass, return_attention=False)[:, -1, :]
            empty_conditions_second_pass = self.forward_babel_eval(empty_input_second_pass)[:, -1, :]

            mix_conditions_second_pass = torch.cat([conditions_second_pass, empty_conditions_second_pass], dim=0)
            final_token, _ = self.diff_loss.sample(mix_conditions_second_pass, temperature=temperature, cfg=cfg).chunk(2, dim=0)
            print(f"[DEBUG] final_token shape: {final_token.shape}")

            # Only the refined (second-pass) token is kept.
            current_motion_tokens = torch.cat([current_motion_tokens, final_token], dim=0)
            print(f"[DEBUG] New current_motion_tokens shape: {current_motion_tokens.shape}")

        print("\n--- [DEBUG] Finished Token Generation ---")

        # Strip the A prompt; return only the generated B tokens, batched.
        B_motion = current_motion_tokens[A_motion.shape[0]:, :].unsqueeze(0)
        return None, B_motion
|
|
| |
| |
    def sample_for_eval_classification(self, clip_text, if_categorial=False, length=196, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4):
        """Sample latents until a learned stop-classifier fires (max 51 steps).

        Args:
            clip_text: prompt string.
            clip_model: CLIP model ('clip') or a DataParallel-wrapped T5
                encoder accessed via `.module.encode` ('t5-xxl').

        Returns:
            Accumulated latents of shape (1, T, C).

        NOTE(review): `self.classify_head` is not created in the visible
        __init__ — it must be attached elsewhere before calling this.
        """
        import clip

        for k in range(51):
            if k == 0:
                x = []
            else:
                x = xs

            # Text features re-encoded every iteration (loop-invariant; left
            # as-is to keep behavior identical).
            if tokenizer == 'clip':
                text = clip.tokenize(clip_text, truncate=True).to(device)

                feat_clip_text = clip_model.encode_text(text).float()
            elif tokenizer == 't5-xxl':
                feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()

            conditions = self.forward(x, feat_clip_text)
            conditions = conditions[:, -1, :]

            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            temperature = 1.0
            cfg = 7.5  # fixed guidance scale on this evaluation path

            # CFG sampling in the diffusion head.
            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            # Stop classifier runs on the conditional hidden state.
            prediction_logits = self.classify_head(conditions)
            probs = torch.sigmoid(prediction_logits)
            predicted_classes = torch.argmax(probs, dim=-1)

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

            # Class 1 — presumably "end of motion" — stops generation AFTER
            # appending the current token; confirm label semantics with training code.
            if predicted_classes == 1:
                break

        return xs
| |
|
|
| |
    def sample_for_eval_CFG_test(self, clip_text, if_categorial=False, length=196, clip_model=None, cfg=1, device=torch.device('cuda'), tokenizer='clip', unit_length=4):
        """CFG sampling test path: guided branch when cfg != 1, plain
        conditional sampling when cfg == 1.

        Args:
            clip_text: prompt; encoded via CLIP ('clip') or a
                DataParallel-wrapped T5 via `.module.encode` ('t5-xxl').

        Returns:
            Accumulated latents of shape (1, T, C).
        """
        import clip
        max_token_len = length // unit_length

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            if cfg != 1:
                # Guided branch: conditional + unconditional forward passes.
                # Text re-encoded every iteration (left as-is).
                if tokenizer == 'clip':
                    text = clip.tokenize(clip_text, truncate=True).to(device)

                    feat_clip_text = clip_model.encode_text(text).float()
                elif tokenizer == 't5-xxl':
                    feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()

                conditions = self.forward(x, feat_clip_text)

                conditions = conditions[:, -1, :]
                empty_clip_text = ''
                if tokenizer == 'clip':
                    empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                    empty_feat_clip_text = clip_model.encode_text(empty_text).float()
                elif tokenizer == 't5-xxl':
                    empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
                    empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                    empty_feat_clip_text = empty_feat_clip_text.to(device)

                empty_conditions = self.forward(x, empty_feat_clip_text)
                empty_conditions = empty_conditions[:, -1, :]
                temperature = 1.0

                # CFG: sample both halves jointly, keep the guided half.
                mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
                sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)

            else:
                # Unguided branch: single conditional pass.
                if tokenizer == 'clip':
                    text = clip.tokenize(clip_text, truncate=True).to(device)
                    feat_clip_text = clip_model.encode_text(text).float()
                elif tokenizer == 't5-xxl':
                    feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()
                    feat_clip_text = feat_clip_text.to(device)

                conditions = self.forward(x, feat_clip_text)

                conditions = conditions[:, -1, :]
                temperature = 1.0
                sampled_token_latent = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs
| |
|
|
    def forward_discrete(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None) -> torch.Tensor:
        """Trunk forward for a discrete-token variant, with optional KV cache.

        Args:
            idx: (b, t) token sequence generated so far ([] on the first step).
            clip_feature: text condition features; embedded and prepended.
            use_cache: when True each block is called as block(x, past, use_cache)
                and `past_key_values` is updated in place, one entry per layer.
            past_key_values: per-layer cache list, or None to start fresh.

        Returns:
            Vocabulary logits from `self.lm_head`.

        NOTE(review): `self.lm_head` is not created in the visible __init__,
        and `self.transformer.wte` is an nn.Linear there, which cannot consume
        integer ids — confirm how this path is wired before use.
        """
        if len(idx) == 0:
            # No tokens yet: the sequence is just the embedded text condition.
            token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(0)

        else:
            b, t = idx.size()

            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # Prepend the text embedding to the token embeddings.
            token_embeddings = self.transformer.wte(idx)
            text_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)
            token_embeddings = torch.cat([text_embeddings, token_embeddings], dim=1)

        x = token_embeddings

        if use_cache:
            if past_key_values is None:
                past_key_values = [None] * len(self.transformer.h)

        for i,block in enumerate(self.transformer.h):
            if use_cache:
                last_past = past_key_values[i]
                x, presents = block(x, last_past, use_cache)
                # Cache mutated in place; callers holding the list see updates.
                past_key_values[i] = list(presents)
            else:
                x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.lm_head(x)

        return logits
|
|
|
|
def forward(self, idx: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
    """Embed continuous tokens, prepend the projected condition, and run the
    transformer stack followed by the output projection.

    Args:
        idx: (B, T, C) continuous tokens; an empty tensor means the
            condition embedding alone is processed.
        feature: conditioning feature, projected by ``cond_embed``.

    Returns:
        Projected hidden states over the (condition + tokens) sequence.
    """
    if len(idx) == 0:
        # Condition-only sequence.
        hidden = self.transformer.cond_embed(feature).unsqueeze(0)
    else:
        _, seq_len, _ = idx.size()
        assert seq_len <= self.config.block_size, (
            f"Cannot forward sequence of length {seq_len}, "
            f"block size is only {self.config.block_size}"
        )

        motion_emb = self.transformer.wte(idx.float())
        cond_emb = self.transformer.cond_embed(feature).unsqueeze(1)
        hidden = torch.cat([cond_emb, motion_emb], dim=1)

    for layer in self.transformer.h:
        hidden = layer(hidden)
    hidden = self.transformer.ln_f(hidden)
    return self.out_proj(hidden)
|
|
|
|
def forward_inference(self, idx: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
    """Inference-time variant of forward().

    Differences from forward(): the condition embedding receives two leading
    unsqueezes (batch of one), and a missing batch dimension on the combined
    sequence is added before the transformer stack.
    """
    if len(idx) == 0:
        hidden = self.transformer.cond_embed(feature).unsqueeze(0)
    else:
        _, seq_len, _ = idx.size()
        assert seq_len <= self.config.block_size, (
            f"Cannot forward sequence of length {seq_len}, "
            f"block size is only {self.config.block_size}"
        )

        motion_emb = self.transformer.wte(idx.float())
        # NOTE(review): the condition is unsqueezed twice here (-> shape
        # (1, 1, n_embd)) whereas forward() uses unsqueeze(1); assumed
        # intentional for batch-of-one inference -- confirm with callers.
        cond_emb = self.transformer.cond_embed(feature).unsqueeze(0)
        hidden = torch.cat([cond_emb.unsqueeze(0), motion_emb], dim=1)

    if hidden.dim() == 2:
        hidden = hidden.unsqueeze(0)

    for layer in self.transformer.h:
        hidden = layer(hidden)
    hidden = self.transformer.ln_f(hidden)
    return self.out_proj(hidden)
| |
|
|
def babel_long(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None, num_subseq=None, length=None) -> torch.Tensor:
    """Forward pass for long multi-segment (BABEL) sequences.

    Each batch row contains several concatenated sub-sequences. For every
    valid sub-sequence, its projected text condition is spliced into the
    embedded token stream at that sub-sequence's start offset (the last
    time step is dropped on each splice so total length stays T).

    Args:
        idx: (B, T, C) continuous tokens for the whole long sequence.
        clip_feature: per-sub-sequence condition features, (B, max_subseq, cond_dim).
        use_cache: enable per-layer kv caching; ``past_key_values`` is
            mutated in place.
        past_key_values: per-layer cache list or None.
        num_subseq: (B,) number of valid sub-sequences per row.
        length: (B, max_subseq) token length of each sub-sequence.

    Returns:
        (B, T, n_embd) output of ``out_proj`` over the spliced sequence.
    """

    b, t, c = idx.size()
    idx = idx.float()
    idx = self.transformer.wte(idx)
    assert (
        t <= self.config.block_size
    ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
    for i in range(b):
        # Only the first num_subseq[i] entries of length/clip_feature are valid.
        length_i = length[i][:num_subseq[i]]
        clip_feature_i = clip_feature[i][:num_subseq[i]]

        # `pointer` is the insertion offset for sub-sequence j's condition:
        # advanced by the previous segment's span plus its condition slot.
        pointer = 0
        for j in range(num_subseq[i]):
            if j > 0:
                # NOTE(review): advances by length_i[j] (current segment),
                # not length_i[j-1] (previous) -- confirm this is intended.
                pointer += length_i[j].item()
                pointer += 1
            pointer = int(pointer)

            # Splice the projected condition in at `pointer`; drop the last
            # time step so the sequence length stays unchanged.
            clip_feature_i_j = self.transformer.cond_embed(clip_feature_i[j].unsqueeze(0)).unsqueeze(1)
            idx[i] = torch.cat([idx[i][:pointer].unsqueeze(0), clip_feature_i_j, idx[i][pointer:-1].unsqueeze(0)], dim=1)[0]

    x = idx

    if use_cache:
        if past_key_values is None:
            # Fresh cache: one slot per transformer block.
            past_key_values = [None] * len(self.transformer.h)

    for i,block in enumerate(self.transformer.h):
        if use_cache:
            last_past = past_key_values[i]
            x, presents = block(x, last_past, use_cache)
            # Mutated in place so the updated cache survives this call.
            past_key_values[i] = list(presents)
        else:
            x = block(x)
    x = self.transformer.ln_f(x)

    logits = self.out_proj(x)
    return logits
| |
|
|
def forward_babel_eval(self, x, return_attention=False) -> torch.Tensor:
    """Run pre-embedded inputs through the stack.

    When ``return_attention`` is True, each block is asked for its attention
    map and the result is ``(logits, [att_per_layer])``; otherwise just the
    logits are returned. ``out_proj`` is applied only if ``use_out_proj``.
    """
    attn_maps = []
    for layer in self.transformer.h:
        if return_attention:
            x, att = layer(x, return_attention=True)
            attn_maps.append(att)
        else:
            x = layer(x)

    x = self.transformer.ln_f(x)
    logits = self.out_proj(x) if self.use_out_proj else x

    if return_attention:
        return logits, attn_maps
    return logits
| |
def forward_babel(self, idx: torch.Tensor, clip_feature: torch.Tensor, A_token_length) -> torch.Tensor:
    """Forward pass for two-segment (A then B) BABEL sequences.

    Each row is rebuilt as:
        [A condition][BOM][A tokens][B condition][BOM][B tokens]
    where ``self.BOM_tag`` is a begin-of-motion embedding (defined outside
    this view) and ``A_token_length[i]`` is row i's A/B split point. Two
    condition and two BOM slots are inserted, so the last two input tokens
    are dropped (``:-2``) to keep the total length at ``block_size``.

    Args:
        idx: (B, T, C) continuous tokens, A tokens followed by B tokens.
        clip_feature: (B, 2, cond_dim) text features for segments A and B.
        A_token_length: (B,) number of A tokens per row.

    Returns:
        (B, block_size, n_embd) hidden states (through ``out_proj`` when
        ``use_out_proj`` is set).
    """
    if len(idx) == 0:
        token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)

    else:
        b, t, c = idx.size()
        idx = idx.float()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # Segment-wise conditions: index 0 is A's text, index 1 is B's.
        A_feature = clip_feature[:, 0, :]
        B_feature = clip_feature[:, 1, :]

        A_text_embeddings = self.transformer.cond_embed(A_feature).unsqueeze(1)
        B_text_embeddings = self.transformer.cond_embed(B_feature).unsqueeze(1)

        # Rebuild every row at full block_size; the concatenation below must
        # come out to exactly block_size steps.
        token_embeddings = torch.zeros(b, self.config.block_size, self.config.n_embd).to(idx.device)
        for i in range(b):
            A_idx = idx[i, :A_token_length[i].item(), :]
            B_idx = idx[i, A_token_length[i].item():-2, :]
            token_embeddings[i, :, :] = torch.cat([A_text_embeddings[i], self.BOM_tag, self.transformer.wte(A_idx), B_text_embeddings[i], self.BOM_tag, self.transformer.wte(B_idx)], dim=0)

    x = token_embeddings
    for block in self.transformer.h:
        x = block(x)
    x = self.transformer.ln_f(x)

    if self.use_out_proj:
        logits = self.out_proj(x)
    else:
        logits = x

    return logits
|
|
def forward_babel2(self, idx: torch.Tensor, clip_feature: torch.Tensor) -> torch.Tensor:
    """BABEL variant with a single text condition prepended along time.

    Unlike forward(), the condition feature is projected without an extra
    unsqueeze (it already carries a time axis) and ``out_proj`` is applied
    only when ``use_out_proj`` is set.
    """
    if len(idx) == 0:
        hidden = self.transformer.cond_embed(clip_feature).unsqueeze(1)
    else:
        _, seq_len, _ = idx.size()
        assert seq_len <= self.config.block_size, (
            f"Cannot forward sequence of length {seq_len}, "
            f"block size is only {self.config.block_size}"
        )

        cond_emb = self.transformer.cond_embed(clip_feature)
        motion_emb = self.transformer.wte(idx.float())
        hidden = torch.cat([cond_emb, motion_emb], dim=1)

    for layer in self.transformer.h:
        hidden = layer(hidden)
    hidden = self.transformer.ln_f(hidden)

    if self.use_out_proj:
        return self.out_proj(hidden)
    return hidden
| |
|
|
def resize_token_embeddings(
    self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, using_old_initilization: bool = False
) -> nn.Embedding:
    """Resize the model's input token embedding matrix.

    Args:
        new_num_tokens: target vocabulary size. Extra rows are newly
            initialized; surplus rows are dropped. ``None`` keeps the size.
        pad_to_multiple_of: when set, the final size is rounded up to a
            multiple of this value (Tensor-Core-friendly sizes).
        using_old_initilization: unused; kept for interface compatibility.

    Returns:
        The (possibly resized) input embedding module.
    """
    model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
    if new_num_tokens is None and pad_to_multiple_of is None:
        # Nothing was resized -- hand back the current embeddings.
        return model_embeds

    # Keep the recorded vocab size in sync with the actual matrix.
    vocab = model_embeds.weight.shape[0]
    self.config.vocab_size = vocab
    self.vocab_size = vocab

    return model_embeds
| |
| def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): |
| old_embeddings = self.get_input_embeddings() |
| new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) |
| old_embeddings_requires_grad = old_embeddings.weight.requires_grad |
| new_embeddings.requires_grad_(old_embeddings_requires_grad) |
| self.set_input_embeddings(new_embeddings) |
|
|
| |
| if pad_to_multiple_of is not None: |
| |
| |
|
|
| |
| |
| |
| new_num_tokens = new_embeddings.weight.shape[0] |
|
|
| |
| |
| if self.get_output_embeddings() is not None and not False: |
| old_lm_head = self.get_output_embeddings() |
| new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens) |
| |
| |
| |
| old_lm_head_requires_grad = old_lm_head.weight.requires_grad |
| new_lm_head.requires_grad_(old_lm_head_requires_grad) |
| self.set_output_embeddings(new_lm_head) |
|
|
| return self.get_input_embeddings() |
| |
| def _get_resized_embeddings( |
| self, |
| old_embeddings: nn.Embedding, |
| new_num_tokens: Optional[int] = None, |
| pad_to_multiple_of: Optional[int] = None, |
| ) -> nn.Embedding: |
| """ |
| Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly |
| initialized vectors at the end. Reducing the size will remove vectors from the end |
| |
| Args: |
| old_embeddings (`torch.nn.Embedding`): |
| Old embeddings to be resized. |
| new_num_tokens (`int`, *optional*): |
| New number of tokens in the embedding matrix. |
| |
| Increasing the size will add newly initialized vectors at the end. Reducing the size will remove |
| vectors from the end. If not provided or `None`, just returns a pointer to the input tokens |
| `torch.nn.Embedding` module of the model without doing anything. |
| pad_to_multiple_of (`int`, *optional*): |
| If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to |
| `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. |
| |
| This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability |
| `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more |
| details about this, or help on choosing the correct value for resizing, refer to this guide: |
| https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc |
| |
| |
| Return: |
| `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if |
| `new_num_tokens` is `None` |
| """ |
|
|
| if pad_to_multiple_of is not None: |
| if not isinstance(pad_to_multiple_of, int): |
| raise ValueError( |
| f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer" |
| ) |
| if new_num_tokens is None: |
| new_num_tokens = old_embeddings.weight.shape[0] |
| new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of |
| else: |
| print( |
| "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding" |
| f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available." |
| " For more details about this, or help on choosing the correct value for resizing, refer to this guide:" |
| " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" |
| ) |
|
|
| if new_num_tokens is None: |
| return old_embeddings |
|
|
| |
| if False: |
| import deepspeed |
|
|
| with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None): |
| old_num_tokens, old_embedding_dim = old_embeddings.weight.size() |
| else: |
| old_num_tokens, old_embedding_dim = old_embeddings.weight.size() |
|
|
| |
| if old_num_tokens == new_num_tokens and not False: |
| return old_embeddings |
|
|
| if not isinstance(old_embeddings, nn.Embedding): |
| raise TypeError( |
| f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You" |
| " should either use a different resize function or make sure that `old_embeddings` are an instance of" |
| f" {nn.Embedding}." |
| ) |
|
|
| |
|
|
| |
| |
| |
| |
| new_embeddings = nn.Embedding( |
| new_num_tokens, |
| old_embedding_dim, |
| device=old_embeddings.weight.device, |
| dtype=old_embeddings.weight.dtype, |
| ) |
|
|
| |
| self._init_weights(new_embeddings) |
|
|
| |
|
|
| |
| n = min(old_num_tokens, new_num_tokens) |
|
|
| |
| if False: |
| import deepspeed |
|
|
| params = [old_embeddings.weight, new_embeddings.weight] |
| with deepspeed.zero.GatheredParameters(params, modifier_rank=0): |
| new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] |
| else: |
| new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] |
|
|
| return new_embeddings |
|
|
|
|
| def _get_resized_lm_head( |
| self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False |
| ) -> nn.Linear: |
| """ |
| Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized |
| vectors at the end. Reducing the size will remove vectors from the end |
| |
| Args: |
| old_lm_head (`torch.nn.Linear`): |
| Old lm head liner layer to be resized. |
| new_num_tokens (`int`, *optional*): |
| New number of tokens in the linear matrix. |
| |
| Increasing the size will add newly initialized vectors at the end. Reducing the size will remove |
| vectors from the end. If not provided or `None`, just returns a pointer to the input tokens |
| `torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults |
| to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim, |
| vocab_size` else `vocab_size, lm_head_dim`. |
| |
| Return: |
| `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is |
| `None` |
| """ |
| if new_num_tokens is None: |
| return old_lm_head |
|
|
| |
| if False: |
| import deepspeed |
|
|
| with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None): |
| old_num_tokens, old_lm_head_dim = ( |
| old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() |
| ) |
| else: |
| old_num_tokens, old_lm_head_dim = ( |
| old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() |
| ) |
|
|
| |
| if old_num_tokens == new_num_tokens and not False: |
| return old_lm_head |
|
|
| if not isinstance(old_lm_head, nn.Linear): |
| raise TypeError( |
| f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You" |
| " should either use a different resize function or make sure that `old_lm_head` are an instance of" |
| f" {nn.Linear}." |
| ) |
|
|
| |
| new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim) |
| has_new_lm_head_bias = old_lm_head.bias is not None |
|
|
| |
| |
| |
| |
| new_lm_head = nn.Linear( |
| *new_lm_head_shape, |
| bias=has_new_lm_head_bias, |
| device=old_lm_head.weight.device, |
| dtype=old_lm_head.weight.dtype, |
| ) |
|
|
| |
| self._init_weights(new_lm_head) |
|
|
| num_tokens_to_copy = min(old_num_tokens, new_num_tokens) |
|
|
| |
| if False: |
| import deepspeed |
|
|
| params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias] |
| with deepspeed.zero.GatheredParameters(params, modifier_rank=0): |
| self._copy_lm_head_original_to_resized( |
| new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias |
| ) |
| else: |
| self._copy_lm_head_original_to_resized( |
| new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias |
| ) |
|
|
| return new_lm_head |
|
|
| def _copy_lm_head_original_to_resized( |
| self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias |
| ): |
| |
| if not transposed: |
| new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] |
| else: |
| new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] |
|
|
| |
| if has_new_lm_head_bias: |
| new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] |
|
|
@classmethod
def from_name(cls, name: str) -> Self:
    """Build a model from a named preset configuration (see llama_configs)."""
    config = LLaMAHFConfig.from_name(name)
    return cls(config)
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: RMSNorm -> causal self-attention ->
    residual add, then RMSNorm -> SwiGLU MLP -> residual add."""

    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.rms_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, last_past=None, use_cache=False, return_attention=False) -> torch.Tensor:
        """Run the block.

        Returns, depending on the flags:
            x                    -- default
            (x, present)         -- use_cache
            (x, attn)            -- return_attention
            (x, present, attn)   -- use_cache and return_attention
        """
        if use_cache:
            if return_attention:
                a, attn = self.attn.forward_attn(self.rms_1(x), last_past, use_cache)
                # BUG FIX: `present` was previously unbound on this path
                # (forward_attn returns only (y, att)), so the return below
                # raised NameError. forward_attn does not expose its kv
                # cache, so no cache entry can be returned here.
                present = None
            else:
                a, present = self.attn(self.rms_1(x), last_past, use_cache)
            x = x + a
        else:
            if return_attention:
                a, attn = self.attn.forward_attn(self.rms_1(x))
            else:
                a = self.attn(self.rms_1(x))
            x = x + a
        x = x + self.mlp(self.rms_2(x))

        if use_cache:
            if return_attention:
                return x, present, attn
            return x, present
        if return_attention:
            return x, attn
        return x
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with L2-normalized q/k (cosine
    attention), rotary position embeddings, and a learned scalar scale."""

    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Fused key/query/value projection, and the output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        self.rope_cache = None  # built lazily on the first forward

        # Learned attention scale, initialized to log2(s^2 - s) with
        # s = block_size.
        def scaling_factor(sequence_threshold):
            return np.log2((sequence_threshold**2) - sequence_threshold)
        scale_init = scaling_factor(self.block_size)
        self.scale = nn.Parameter(torch.tensor(scale_init))

    def forward(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
        """Causal attention over x of shape (B, T, C).

        Args:
            x: input hidden states.
            last_past: optional (key, value) tensors from previous steps.
            use_cache: when True, also return the updated (key, value) pair.

        Returns:
            y, or (y, present) when ``use_cache`` is set.
        """
        B, T, C = x.size()

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        if use_cache:
            if last_past is not None:
                # Prepend cached keys/values along the time axis.
                past_key, past_value = last_past
                k = torch.cat([past_key, k], dim=-2)
                v = torch.cat([past_value, v], dim=-2)

        present = (k, v) if use_cache else None

        # Cosine attention: unit-norm queries and keys.
        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        if self.rope_cache is None:
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.n_embd // self.n_head,
                dtype=x.dtype,
                device=x.device,
            )

        # NOTE(review): on the cache path q is rotated starting at position 0
        # while k already spans earlier positions -- confirm the caller
        # accounts for this.
        q = apply_rope(q, self.rope_cache)
        k = apply_rope(k, self.rope_cache)

        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=self.scale.item())

        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.c_proj(y)

        if use_cache:
            return y, present
        return y

    def forward_attn(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
        """Like forward(), but also returns the post-softmax attention map.

        Returns:
            (y, att) with att of shape (B, n_head, T, T).

        NOTE(review): unlike forward(), the SDPA call here uses the default
        1/sqrt(d) scale instead of the learned ``self.scale``; kept as-is to
        preserve behavior -- confirm the asymmetry is intentional.
        """
        B, T, C = x.size()

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        if use_cache:
            if last_past is not None:
                past_key, past_value = last_past
                k = torch.cat([past_key, k], dim=-2)
                v = torch.cat([past_value, v], dim=-2)

        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        if self.rope_cache is None:
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.n_embd // self.n_head,
                dtype=x.dtype,
                device=x.device,
            )

        q = apply_rope(q, self.rope_cache)
        k = apply_rope(k, self.rope_cache)

        # BUG FIX: the returned map was previously softmaxed with NO causal
        # mask, so it did not describe the (causal) attention actually
        # applied to v below. Mask future positions before the softmax,
        # mirroring is_causal=True in scaled_dot_product_attention.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        causal = torch.ones(att.size(-2), att.size(-1), dtype=torch.bool, device=att.device).tril()
        att = att.masked_fill(~causal, float("-inf"))
        att = F.softmax(att, dim=-1)

        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.c_proj(y)

        return y, att
|
|
class LengthCausalSelfAttention(nn.Module):
    """Causal self-attention with RoPE plus a text-prefix mask.

    Positions covered by ``y_mask`` (the text condition prefix) may attend
    to each other bidirectionally; everything else follows a strict causal
    (lower-triangular) mask. Unlike ``CausalSelfAttention``, q/k are not
    L2-normalized and the default 1/sqrt(d) attention scale is used.
    """

    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Fused q/k/v projection and output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)

        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        self.rope_cache = None  # built lazily on the first forward

    def forward(self, x: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C) hidden states.
        # y_mask: (B, T_text) 0/1 mask marking valid text-prefix positions;
        # T_text <= T (the mask is zero-padded up to T below).
        B, T, C = x.size()

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        if self.rope_cache is None:
            # One-time rotary table for up to block_size positions.
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.n_embd // self.n_head,
                dtype=x.dtype,
                device=x.device,
            )

        q = apply_rope(q, self.rope_cache)
        k = apply_rope(k, self.rope_cache)

        # Base causal mask: each position attends to itself and the past.
        attn_mask = torch.ones(T, T, dtype=torch.bool, device=x.device)
        attn_mask = torch.tril(attn_mask)
        attn_mask = attn_mask.unsqueeze(0).expand(B, -1, -1)

        # Bidirectional block over the text prefix: (B, T_text, T_text)
        # outer product of the mask with itself, zero-padded to (B, T, T).
        text_mask = y_mask.unsqueeze(2)*y_mask.unsqueeze(1)
        text_mask = F.pad(text_mask, (0, T-y_mask.shape[1], 0, T-y_mask.shape[1]), mode='constant', value=0)
        attn_mask = torch.logical_or(attn_mask, text_mask)

        # unsqueeze(1) broadcasts the mask across heads.
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask.unsqueeze(1), dropout_p=0.0, is_causal=False)

        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.c_proj(y)

        return y
|
|
|
|
class MLP(nn.Module):
    """SwiGLU feed-forward block: silu(W1 x) * (W2 x) -> W3.

    The hidden width is ~(8/3) * n_embd, rounded up to a multiple of 256.
    """

    def __init__(self, config: "LLaMAHFConfig") -> None:
        super().__init__()
        base_width = int(2 * (4 * config.n_embd) / 3)
        multiple = 256
        # Round the hidden width up to the next multiple of 256.
        hidden_width = ((base_width - 1) // multiple + 1) * multiple

        self.c_fc1 = nn.Linear(config.n_embd, hidden_width, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, hidden_width, bias=False)
        self.c_proj = nn.Linear(hidden_width, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated (SwiGLU) feed-forward transform."""
        gated = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        return self.c_proj(gated)
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization (no mean subtraction).

    Normalizes by the root-mean-square of the features along ``dim`` and
    applies a learned per-feature gain.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        # Learned gain, one scalar per feature.
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return scale * x / rms(x); eps sits inside the rsqrt for stability."""
        mean_sq = (x * x).mean(dim=self.dim, keepdim=True)
        return self.scale * (x * torch.rsqrt(mean_sq + self.eps))
|
|
|
|
def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor:
    """Precompute the complex rotary-embedding table.

    Returns a (seq_len, n_elem // 2) complex tensor whose entry [t, j] is
    exp(i * t * theta_j) with theta_j = base^(-2j / n_elem).

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    # Per-pair rotation frequencies.
    exponents = torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem
    theta = 1.0 / (base ** exponents)

    # One angle per (position, frequency) pair.
    positions = torch.arange(seq_len, dtype=dtype, device=device)
    angles = torch.outer(positions, theta)

    # Low-precision inputs are computed in float32 and stored as complex32;
    # everything else stays in its dtype and is stored as complex64.
    needs_cast = dtype in (torch.float16, torch.bfloat16, torch.int8)
    working_dtype = torch.float32 if needs_cast else dtype
    complex_dtype = torch.complex32 if needs_cast else torch.complex64
    return torch.polar(
        torch.ones_like(angles).to(working_dtype), angles.to(working_dtype)
    ).to(complex_dtype)
|
|
|
|
| def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: |
| x = x.transpose(1, 2) |
|
|
| |
| T = x.size(1) |
| rope_cache = rope_cache[:T] |
| |
| xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) |
| rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3)) |
| x_out = torch.view_as_real(xc * rope_cache).flatten(3) |
| return x_out.transpose(1, 2).type_as(x) |