import os
import inspect

import torch
from torch import nn, einsum
from einops import rearrange, repeat
from einops_exts import rearrange_many

from .mae_vit import MAEViT
from .htsat import create_htsat_model
from .LMdecoder import LMDecoder, LMDecoder_qlora
|
|
| from transformers.modeling_utils import PreTrainedModel |
| from .configuration_maelm import MAELMConfig |
|
|
class ArgsHandler:
    """Uniform access to a function call's arguments, whether they were
    passed positionally or as keywords.

    The argument-name list of ``module.funcname`` is computed once with
    ``inspect.signature`` and cached on the module, so that repeated calls
    (e.g. from forward hooks) do not re-parse the signature.
    """
    def __init__(self, module, funcname, fargs, fkargs):
        self.fargs = list(fargs)
        self.fkargs = fkargs
        func = getattr(module, funcname)
        fal_repr = f"{funcname}_argnames_list"
        if (argns_list := getattr(module, fal_repr, None)) is None:
            self.func_sig = inspect.signature(func)
            self.argnames_list = list(self.func_sig.parameters.keys())
            setattr(module, fal_repr, self.argnames_list)
        else:
            self.argnames_list = argns_list
|
|
| def get_arg(self, arg_name): |
| if arg_name in self.fkargs: |
| arg = self.fkargs[arg_name] |
| else: |
| arg = self.fargs[self.argnames_list.index(arg_name)] |
| return arg |
|
|
| def set_arg(self, arg_name, arg_value): |
| if arg_name in self.fkargs: |
| self.fkargs[arg_name] = arg_value |
| else: |
| self.fargs[self.argnames_list.index(arg_name)] = arg_value |
|
|
    def return_all_args(self):
        return tuple(self.fargs), self.fkargs
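
# Usage sketch (hypothetical toy example, not from the training code):
# ArgsHandler lets a forward pre-hook read or rewrite an argument by name,
# regardless of whether the caller passed it positionally or as a keyword.
#
#   layer = nn.Linear(4, 4)
#   def prehook(m, args, kwargs):
#       ahl = ArgsHandler(m, 'forward', args, kwargs)
#       x = ahl.get_arg('input')        # nn.Linear.forward's only parameter
#       ahl.set_arg('input', x * 2.0)
#       return ahl.return_all_args()
#   layer.register_forward_pre_hook(prehook, with_kwargs=True)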
|
|
class SquaredReLU(nn.Module):
    """Squared ReLU activation: relu(x) ** 2."""

    def forward(self, x):
        return torch.pow(torch.relu(x), 2)
|
|
def FeedForward(dim, out_dim, mult=4, act='gelu'):
    """
    Pre-norm MLP block (lucidrains implementation, slightly modified with
    the ``act`` and ``out_dim`` parameters).
    """
    acts = dict(
        gelu=nn.GELU,
        sqrelu=SquaredReLU,
        relu=nn.ReLU
    )
    assert act in acts, f"act can only be one of {list(acts.keys())}"

    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        acts[act](),
        nn.Linear(inner_dim, out_dim, bias=False)
    )
|
|
|
|
| class PerceiverAttentionLayer(nn.Module): |
| def __init__( |
| self, |
| *, |
| feat_dim, |
| latent_dim, |
| dim_head=64, |
| heads=8 |
| ): |
| super().__init__() |
| self.scale = dim_head ** -0.5 |
| self.heads = heads |
| self.dim_head = dim_head |
|
|
| inner_dim = dim_head * heads |
|
|
| |
| self.norm_media = nn.LayerNorm(feat_dim) |
| self.norm_latents = nn.LayerNorm(latent_dim) |
|
|
| self.to_q = nn.Linear(latent_dim, inner_dim, bias=False) |
| self.to_k = nn.Linear(feat_dim, inner_dim, bias=False) |
| self.to_v = nn.Linear(feat_dim, inner_dim, bias=False) |
| self.to_out = nn.Linear(inner_dim, latent_dim, bias=False) |
|
|
    def forward(self, features, latents):
        """
        Latent vectors cross-attend to the input features.
        :param features: Tensor (n_batch, n_features, feat_dim)
            input (audio) features
        :param latents: Tensor (n_batch, n_latents, latent_dim)
            learnt latent vectors from which the queries are computed,
            replicated along the batch dimension.
        :return: Tensor (n_batch, n_latents, latent_dim)
        """
        assert features.ndim == 3
        assert latents.ndim == 3
        assert features.shape[0] == latents.shape[0]
|
|
| n_heads = self.heads |
| n_batch, n_features, dim = features.shape |
| n_queries = latents.shape[1] |
|
|
| |
| x = self.norm_media(features) |
| latents = self.norm_latents(latents) |
|
|
| |
| |
| q = self.to_q(latents) |
| q = rearrange(q, 'b q (h d) -> b h q d', h=n_heads) |
| assert q.shape == torch.Size([n_batch, n_heads, n_queries, self.dim_head]) |
|
|
| |
| |
| ''' |
| kv_input = torch.cat((x, latents), dim=-2) |
| n_features_latents = n_features + n_queries |
| ''' |
|
|
| kv_input = x |
| n_features_latents = n_features |
|
|
| |
| k = self.to_k(kv_input) |
| v = self.to_v(kv_input) |
| |
|
|
| |
| |
| k, v = rearrange_many((k, v), 'b f (h d) -> b h f d', h=n_heads) |
| assert v.shape == torch.Size([n_batch, n_heads, n_features_latents, self.dim_head]) |
|
|
| |
| q = q * self.scale |
|
|
| |
|
|
| |
| |
| sim = einsum('b h q d, b h f d -> b h q f', q, k) |
|
|
| |
| sim = sim - sim.amax(dim=-1, keepdim=True).detach() |
| alphas = sim.softmax(dim=-1) |
|
|
| |
| out = einsum('b h q f, b h f v -> b h q v', alphas, v) |
|
|
| |
| out = rearrange(out, 'b h q v -> b q (h v)') |
| return self.to_out(out) |
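
# Minimal usage sketch (toy sizes, for illustration): the layer maps a variable
# number of input features onto a fixed number of latent tokens.
#
#   layer = PerceiverAttentionLayer(feat_dim=1024, latent_dim=4096,
#                                   dim_head=64, heads=8)
#   feats = torch.randn(2, 196, 1024)   # (batch, n_features, feat_dim)
#   lats = torch.randn(2, 32, 4096)     # (batch, n_latents, latent_dim)
#   assert layer(feats, lats).shape == (2, 32, 4096)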
|
|
|
|
class MAEForCausalLM(PreTrainedModel):
    """Audio-conditioned causal language model.

    Args:
        config (MAELMConfig): Model configuration. ``config.backbone`` is the
            config dict for the audio encoder and ``config.neck`` the config
            dict for the language-model decoder.
    """

    config_class = MAELMConfig
|
|
    def __init__(self, config: MAELMConfig) -> None:
        super().__init__(config)
        backbone = config.backbone
        assert backbone is not None
        bk_name = backbone.pop('name')
        self.bk_name = bk_name
        if bk_name == 'MAEViT':
            ckpt_path = backbone.pop('ckpt', None)
            self.backbone = MAEViT(**backbone)
        elif bk_name == 'HTSAT':
            ckpt_path = backbone.pop('ckpt', None)
            self.backbone = create_htsat_model(backbone)
            if ckpt_path is not None:
                ckpt = torch.load(ckpt_path, map_location='cpu')
                self.backbone.load_state_dict(ckpt['state_dict'])
        elif bk_name == 'qformer':
            raise NotImplementedError("qformer backbone is not implemented")
        else:
            raise NotImplementedError(f"unknown backbone '{bk_name}'")

        neck = config.neck
        assert neck is not None
        nk_name = neck.pop('name')
        if nk_name == 'LMDecoder':
            self.neck = LMDecoder(**neck)
        elif nk_name == 'LMDecoder_qlora':
            self.neck = LMDecoder_qlora(**neck)
        else:
            raise NotImplementedError(f"unknown neck '{nk_name}'")
        # From here on, self.config refers to the language model's config.
        self.config = self.neck.LMconfig

        self.neck.lm.model.gradient_checkpointing = False

        self.register_buffer('ones', torch.ones((1, 4096), dtype=torch.long), persistent=False)
        self.graft_adapter()
        self.init_weights()
        # The whole model is trained and run in bfloat16.
        for p in self.parameters():
            p.data = p.data.to(torch.bfloat16)

        self.first_run = True
| |
    def graft_adapter(self):
        # Learnt latent tokens: one set consumed as placeholder embeddings by
        # the LM, one set used as queries of the perceiver resampler.
        adapter_latent_len = 32
        self.adapter_latent_len = adapter_latent_len
        self.adapter_latent = nn.Parameter(
            torch.rand((1, adapter_latent_len, self.config.hidden_size), dtype=torch.float))
        resampler_latent_len = 32
        self.resampler_latent_len = resampler_latent_len
        self.resampler_latent = nn.Parameter(
            torch.rand((1, resampler_latent_len, self.config.hidden_size), dtype=torch.float))

        lm_dim = self.config.hidden_size
        heads = 8
        if self.bk_name == 'HTSAT':
            feat_dim = 1024
            depth = len(self.backbone.layers[2].blocks)
        else:
            feat_dim = 768
            depth = len(self.neck.lm.model.layers) // 2

        # One bottleneck adapter per hooked LM layer.
        self.adapter = nn.ModuleList([])
        for _ in range(depth):
            self.adapter.append(nn.ModuleList([
                Adapter(input_size=self.config.hidden_size),
            ]))

        # Three perceiver resampler blocks (cross-attention + feed-forward).
        self.samplers = nn.ModuleList([])
        for _ in range(3):
            self.samplers.append(nn.ModuleList([
                PerceiverAttentionLayer(feat_dim=feat_dim, latent_dim=lm_dim, dim_head=64, heads=heads),
                FeedForward(dim=lm_dim, out_dim=lm_dim, mult=4),
            ]))
        self.norm = nn.LayerNorm(lm_dim)

    def init_weights(self):
        try:
            super().init_weights()
        except Exception:
            pass

        if getattr(self, 'adapter_latent', None) is not None:
            self.adapter_latent.data.normal_(mean=0.0, std=0.02)
        if getattr(self, 'resampler_latent', None) is not None:
            self.resampler_latent.data.normal_(mean=0.0, std=0.02)
|
|
    def forward_resampler(self, x):
        # x: (n_batch, n_features, feat_dim) backbone features.
        latents = repeat(self.resampler_latent, 'b n d -> (bs b) n d', bs=x.shape[0])
        for attn, ff in self.samplers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        v2t_feats = self.norm(latents)
        # (n_batch, resampler_latent_len, hidden_size)
        return v2t_feats
|
|
|
|
    def hook_adapter(self, audio_embedding, lm, v2t_feats):
        """Register forward pre-hooks on the LM that splice the resampled
        audio latents into the first layer's hidden states and run a
        bottleneck adapter on every other layer."""

        class PHooker:
            adapter = self.adapter
            y = v2t_feats
            handles_list = list()
            cnter = 0

            def layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                adapt = self.adapter[self.cnter][0]
                hs = ahl.get_arg("hidden_states")
                adapter_residual = hs
                neo_hs = adapt(hs, adapter_residual)
                self.cnter += 1
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def first_layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                neo_lm_latents = self.y
                hs = ahl.get_arg("hidden_states")
                # Audio positions were encoded as negative input ids.
                hs_msk = self.lm_ahl.get_arg("input_ids") < 0
                neo_hs = hs.masked_scatter(hs_msk.unsqueeze(-1), neo_lm_latents)
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def lm_prehook(self, m, margs, mkargs):
                self.lm_ahl = ArgsHandler(m, 'forward', margs, mkargs)
                return None

            def last_layer_hook(self, m, margs, mkargs):
                self.cnter = 0

        # Drop hooks from a previous forward pass before installing new ones.
        if getattr(lm, 'phooker', False):
            for _ in lm.phooker.handles_list:
                _.remove()
            del lm.phooker
            lm.phooker = None
        phooker = PHooker()
        phooker.handles_list.append(lm.register_forward_pre_hook(phooker.lm_prehook, with_kwargs=True))
        phooker.handles_list.append(lm.model.layers[0].register_forward_pre_hook(phooker.first_layer_prehook, with_kwargs=True))
        for ii in range(1, len(lm.model.layers), 2):
            l = lm.model.layers[ii]
            handle = l.register_forward_pre_hook(phooker.layer_prehook, with_kwargs=True)
            phooker.handles_list.append(handle)
        phooker.handles_list.append(lm.model.layers[-1].register_forward_pre_hook(phooker.last_layer_hook, with_kwargs=True))
        lm.phooker = phooker
        return None
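
    # Hook layout sketch (e.g. for a 32-layer LM; sizes are illustrative):
    #   lm itself       -> lm_prehook          captures input_ids
    #   layers[0]       -> first_layer_prehook splices audio latents into hidden states
    #   layers[1,3,...] -> layer_prehook       bottleneck adapter on every other layer
    #   layers[-1]      -> last_layer_hook     resets the adapter counter
    # The hooks are re-installed on every forward pass because v2t_feats
    # changes with each batch.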
|
|
|
|
|
|
    def prepare_ids(self, batch, audio_ids):
        toker = self.neck.tokenizer
        with torch.no_grad():
            input_ids = batch['input_ids']
            att_msk = batch['attention_mask']
            au_crds = batch['audio_crds']
            ans_crds = batch['ans_crds']
            bsz = input_ids.shape[0]

            merged_ids, merged_msk, label_ids = list(), list(), list()
            for i in range(bsz):
                # Audio placeholders are encoded as negative ids so that
                # first_layer_prehook can locate them with `input_ids < 0`.
                cur_merged_ids = torch.cat([-1 * audio_ids[i] - 1, input_ids[i, au_crds[i]:]])
                cur_au_msk = torch.ones(audio_ids.shape[1], device=audio_ids.device)
                cur_merged_msk = torch.cat([cur_au_msk, att_msk[i, au_crds[i]:]])
                cur_label_ids = cur_merged_ids.clone().detach()
                # Mask everything before the answer out of the loss.
                cur_label_ids[:audio_ids.shape[1] + ans_crds[i]] = -100

                merged_ids.append(cur_merged_ids)
                merged_msk.append(cur_merged_msk)
                label_ids.append(cur_label_ids)

            merged_ids = torch.stack(merged_ids, dim=0)
            merged_msk = torch.stack(merged_msk, dim=0)
            label_ids = torch.stack(label_ids, dim=0)

            assert merged_ids.shape[0] == bsz
            assert merged_ids.shape == merged_msk.shape

            label_msk = merged_msk.clone()
            assert label_msk.shape == merged_msk.shape
            assert merged_msk[:, -1].max() == 1

            # Redundant with the per-sample fill above; kept for safety.
            for i in range(len(ans_crds)):
                label_ids[i, :audio_ids.shape[1] + ans_crds[i]].fill_(-100)

            merged_labels = label_ids
            merged_ids[merged_ids.eq(-100)] = toker.pad_token_id

        return merged_ids, merged_msk, merged_labels
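
    # Toy sketch of the id merging (values are illustrative): with
    # audio_ids[i] = [0, 1, 2], au_crds[i] = 1 and ans_crds[i] = 3:
    #   merged_ids[i]    = [-1, -2, -3, t1, t2, t3, ...]   # negatives mark audio slots
    #   merged_labels[i] = [-100, ..., -100, t_ans, ...]   # loss only on answer tokens
    # where t* are the original token ids from position au_crds[i] onwards.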
|
|
    def forward(self, batch, **kwargs):
        """Forward computation during training.

        Args:
            batch (dict): Batch with keys 'input_ids', 'attention_mask',
                'audio_crds', 'ans_crds' and 'spectrogram'.
            kwargs: Any keyword arguments to be used to forward.
        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        bsz = len(batch['input_ids'])
        device = batch['input_ids'].device
        float_type = next(self.parameters()).dtype
        spectrogram = batch['spectrogram'].type(float_type)
        audio_embedding = self.backbone(spectrogram).detach()
        resampler_feats = self.forward_resampler(audio_embedding)
        self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats)
| |
| |
| |
| audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device) |
| assert audio_ids.max() < 100 |
| merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids) |
| |
        try:
            assert merged_ids.shape == merged_labels.shape
            outs = self.neck(input_ids=merged_ids.contiguous().long(),
                             flatten_embs=self.adapter_latent.flatten(0, 1),
                             attention_mask=merged_msk.contiguous().long(),
                             labels=merged_labels.contiguous().long(), use_cache=False)
        except Exception:
            import traceback
            traceback.print_exc()
            raise
| |
|
|
| |
        # "doing_eval" is expected to hold a Python literal such as 'True'.
        if eval(os.environ.get("doing_eval", 'False')):
            outs.merged_ids = merged_ids.cpu()
            outs.merged_labels = merged_labels.cpu()
|
|
| return outs |
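
    # Data-flow sketch for forward (sizes are illustrative assumptions):
    #   spectrogram --backbone--> audio_embedding (bsz, n_feats, feat_dim)
    #   audio_embedding --forward_resampler--> resampler_feats (bsz, 32, hidden_size)
    #   resampler_feats --hook_adapter--> spliced into layers[0] via negative ids
    #   merged_ids (bsz, 32 + text_len) --neck--> LM loss over answer tokens only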
|
|
|
|
    def forward_test(self, batch, **kwargs):
        """Forward computation during inference (generation).

        Args:
            batch (dict): Batch with the same keys as in ``forward``.
            kwargs: Any keyword arguments to be used to forward.
        Returns:
            Generation outputs from the neck.
        """
        bsz = len(batch['input_ids'])
        device = batch['input_ids'].device
        float_type = next(self.parameters()).dtype
        spectrogram = batch['spectrogram'].type(float_type)
        audio_embedding = self.backbone(spectrogram).detach()
        resampler_feats = self.forward_resampler(audio_embedding)
        self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats)
| |
| audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device) |
| assert audio_ids.max() < 100 |
|
|
| merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids) |
| au_crds = batch['audio_crds'] |
| ans_crds = batch['ans_crds'] |
| |
| aid_len = audio_ids.shape[-1] |
| |
|
|
        toker = self.neck.tokenizer
        with torch.no_grad():
            # Left-pad every prompt with EOS so that all prompts end right
            # before their answer position.
            pad_token = toker.encode(toker.eos_token)[0]
            padded_merged_ids = self.ones[:, :aid_len + max(ans_crds)].repeat(bsz, 1).clone().detach() * pad_token
            for i in range(bsz):
                assert au_crds[i] <= ans_crds[i]
                cur_ids = merged_ids[i][:aid_len + ans_crds[i]]
                padded_merged_ids[i][max(ans_crds) - ans_crds[i]:] = cur_ids

            outs = self.neck.generate(padded_merged_ids, self.adapter_latent.flatten(0, 1))
        return outs
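
    # Left-padding sketch (illustrative, with aid_len = 2 and ans_crds = [3, 1],
    # where a = audio placeholder id and t* = text token ids):
    #   row 0: [a, a, t0, t1, t2]      # no padding, answer starts at the end
    #   row 1: [eos, eos, a, a, t0]    # two EOS tokens of left padding
    # so that generation starts at the answer position for every sample.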
|
|
|
|
|
|
from transformers.activations import ACT2FN
|
|
| class Adapter(nn.Module): |
| """ |
| Implementation of a sequential bottleneck adapter block. |
| """ |
| def __init__( |
| self, |
| input_size, |
| down_sample=None, |
| ): |
| super().__init__() |
|
|
| self.input_size = input_size |
|
|
| |
| self.down_sample = down_sample |
| if down_sample is None: |
| self.down_sample = self.input_size // 2 |
|
|
| self.adapter_norm_before = nn.LayerNorm(self.input_size) |
| self.adapter_down = nn.Linear(self.input_size, self.down_sample) |
| self.non_linearity = ACT2FN["silu"] |
|
|
| |
| self.adapter_up = nn.Linear(self.down_sample, self.input_size) |
|
|
| |
| self.scaling = nn.Parameter(torch.ones(1)) |
|
|
| self.adapter_down.apply(self._init_weights) |
| self.adapter_up.apply(self._init_weights) |
|
|
    def forward(self, x, residual_input):
        # Bottleneck: norm -> down-project -> SiLU -> up-project -> scale -> residual.
        down = self.non_linearity(self.adapter_down(self.adapter_norm_before(x)))
        up = self.adapter_up(down)
        up = up * self.scaling
        output = up + residual_input
        return output
|
|
| @staticmethod |
| def _init_weights(module): |
| """Initialize the weights.""" |
| if isinstance(module, (nn.Linear, nn.Embedding)): |
| |
| module.weight.data.normal_(mean=0.0, std=0.02) |
| elif isinstance(module, nn.LayerNorm): |
| module.bias.data.zero_() |
| module.weight.data.fill_(1.0) |
| if isinstance(module, nn.Linear) and module.bias is not None: |
| module.bias.data.zero_() |
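
# Usage sketch (toy sizes, for illustration): the adapter is residual, so its
# output shape matches its input.
#
#   adapter = Adapter(input_size=4096)   # bottleneck of 4096 // 2 = 2048
#   h = torch.randn(2, 16, 4096)
#   assert adapter(h, h).shape == (2, 16, 4096)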
|
|
|
|