| ''' |
| Code Reference: |
| |
| * https://github.com/jadore801120/attention-is-all-you-need-pytorch/ |
| * https://github.com/GT-RIPL/CODA-Prompt |
| * https://github.com/openai/CLIP |
| ''' |
|
|
| import os |
| import math |
| import torch |
| import numpy as np |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from functools import partial |
from collections import Counter
from torch.distributions.normal import Normal
| from timm.models.vision_transformer import PatchEmbed |
| from timm.models.layers import trunc_normal_, DropPath |
| from scipy.special import softmax |
|
|
| from .petl.adapter import Adapter, MaskedAdapter |
| from .petl.proj import Proj |
| from .prompt import L2P |
|
|
| |
| class SparseDispatcher(object): |
| """Helper for implementing a mixture of experts. |
| The purpose of this class is to create input minibatches for the |
| experts and to combine the results of the experts to form a unified |
| output tensor. |
| There are two functions: |
| dispatch - take an input Tensor and create input Tensors for each expert. |
| combine - take output Tensors from each expert and form a combined output |
| Tensor. Outputs from different experts for the same batch element are |
| summed together, weighted by the provided "gates". |
| The class is initialized with a "gates" Tensor, which specifies which |
| batch elements go to which experts, and the weights to use when combining |
| the outputs. Batch element b is sent to expert e iff gates[b, e] != 0. |
| The inputs and outputs are all two-dimensional [batch, depth]. |
| Caller is responsible for collapsing additional dimensions prior to |
| calling this class and reshaping the output to the original shape. |
| See common_layers.reshape_like(). |
| Example use: |
| gates: a float32 `Tensor` with shape `[batch_size, num_experts]` |
| inputs: a float32 `Tensor` with shape `[batch_size, input_size]` |
| experts: a list of length `num_experts` containing sub-networks. |
| dispatcher = SparseDispatcher(num_experts, gates) |
| expert_inputs = dispatcher.dispatch(inputs) |
| expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] |
| outputs = dispatcher.combine(expert_outputs) |
| The preceding code sets the output for a particular example b to: |
| output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) |
| This class takes advantage of sparsity in the gate matrix by including in the |
| `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. |
| """ |
|
|
| def __init__(self, num_experts, gates): |
| """Create a SparseDispatcher.""" |
|
|
| self._gates = gates |
| self._num_experts = num_experts |
|
|
        # sort experts
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        # drop indices
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # get the batch index for each expert's samples
        self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
        # number of samples routed to each expert
        self._part_sizes = (gates > 0).sum(0).tolist()
        # expand gates to match self._batch_index
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)
|
|
| def dispatch(self, inp): |
| """Create one input Tensor for each expert. |
        The `Tensor` for expert `i` contains the slices of `inp` corresponding
        to the batch elements `b` where `gates[b, i] > 0`.
        Args:
          inp: a `Tensor` of shape `[batch_size, <extra_input_dims>]`
| Returns: |
| a list of `num_experts` `Tensor`s with shapes |
| `[expert_batch_size_i, <extra_input_dims>]`. |
| """ |
|
|
| |
|
|
| inp_exp = inp[self._batch_index].squeeze(1) |
| return torch.split(inp_exp, self._part_sizes, dim=0) |
|
|
| def combine(self, expert_out, multiply_by_gates=True): |
| """Sum together the expert output, weighted by the gates. |
| The slice corresponding to a particular batch element `b` is computed |
| as the sum over all experts `i` of the expert output, weighted by the |
| corresponding gate values. If `multiply_by_gates` is set to False, the |
| gate values are ignored. |
| Args: |
| expert_out: a list of `num_experts` `Tensor`s, each with shape |
| `[expert_batch_size_i, <extra_output_dims>]`. |
| multiply_by_gates: a boolean |
| Returns: |
| a `Tensor` with shape `[batch_size, <extra_output_dims>]`. |
| """ |
| |
|
|
| stitched = torch.cat(expert_out, 0) |
| if multiply_by_gates: |
| stitched = stitched.mul(self._nonzero_gates) |
|
|
| zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), device=stitched.device) |
| |
|
|
| combined = zeros.index_add(0, self._batch_index, stitched.float()) |
| |
| |
| return combined |
|
|
| def expert_to_gates(self): |
| """Gate values corresponding to the examples in the per-expert `Tensor`s. |
| Returns: |
          a list of `num_experts` one-dimensional `Tensor`s with type `torch.float32`
| and shapes `[expert_batch_size_i]` |
| """ |
| |
| return torch.split(self._nonzero_gates, self._part_sizes, dim=0) |
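
    # Minimal usage sketch (illustrative only; the toy gate matrix and shapes
    # below are assumptions, not part of the training pipeline):
    #
    #   gates = torch.tensor([[0.7, 0.3, 0.0],
    #                         [0.0, 1.0, 0.0]])          # [batch=2, experts=3]
    #   experts = [nn.Linear(8, 8) for _ in range(3)]
    #   dispatcher = SparseDispatcher(3, gates)
    #   parts = dispatcher.dispatch(torch.randn(2, 8))   # one slice per expert
    #   outs = [experts[i](parts[i]) for i in range(3)]
    #   combined = dispatcher.combine(outs)              # [2, 8], gate-weighted sum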
|
|
| |
| class LayerNorm(nn.LayerNorm): |
| """Subclass torch's LayerNorm to handle fp16.""" |
|
|
| def forward(self, x: torch.Tensor): |
| orig_type = x.dtype |
| ret = super().forward(x.type(torch.float32)) |
| return ret.type(orig_type) |
|
|
| class QuickGELU(nn.Module): |
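    """CLIP's fast GELU approximation: x * sigmoid(1.702 * x) ~= GELU(x)."""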
| def forward(self, x: torch.Tensor): |
| return x * torch.sigmoid(1.702 * x) |
|
|
| |
| class MultiHeadAttention(nn.Module): |
| def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): |
| super().__init__() |
| self.dim = dim |
| self.num_heads = num_heads |
| head_dim = dim // num_heads |
| |
| self.scale = qk_scale or head_dim ** -0.5 |
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
| self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0. else nn.Identity() |
| self.proj = nn.Linear(dim, dim) |
| self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity() |
| self.attn_gradients = None |
| self.attention_map = None |
|
|
| def save_attn_gradients(self, attn_gradients): |
| self.attn_gradients = attn_gradients |
| |
| def get_attn_gradients(self): |
| return self.attn_gradients |
| |
| def save_attention_map(self, attention_map): |
| self.attention_map = attention_map |
| |
| def get_attention_map(self): |
| return self.attention_map |
| |
| def forward(self, x, attn_mask=None, register_hook=False, prompt=None): |
|
|
| B, N, C = x.shape |
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| q, k, v = qkv[0], qkv[1], qkv[2] |
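
        # Prefix-tuning-style prompts: prepend learned key/value tokens along
        # the sequence dimension so every query can attend to them.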
|
|
| if prompt is not None: |
| pk, pv = prompt |
| pk = pk.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) |
| pv = pv.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) |
| k = torch.cat((pk,k), dim=2) |
| v = torch.cat((pv,v), dim=2) |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
| return x |
|
|
| class MultiHeadAttention_LoRA(MultiHeadAttention): |
|
|
    '''
    Attention module with LoRA applied to the k and v projections.
    '''
|
|
| def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): |
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) |
|
|
| self.lora_rank = lora_rank |
| |
| self.lora_A_k = nn.Linear(self.dim, self.lora_rank, bias=lora_bias) |
| self.lora_B_k = nn.Linear(self.lora_rank, self.dim, bias=lora_bias) |
| self.lora_A_v = nn.Linear(self.dim, self.lora_rank, bias=lora_bias) |
| self.lora_B_v = nn.Linear(self.lora_rank, self.dim, bias=lora_bias) |
| self.apply_lora = False |
|
|
        self.cur_matrix = torch.zeros(self.dim, self.dim)
| self.n_cur_matrix = 0 |
|
|
| def init_param(self): |
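        # Standard LoRA initialization: A ~ Kaiming-uniform, B = 0, so B @ A
        # starts at zero and the adapted weights initially equal the backbone.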
|
|
| nn.init.kaiming_uniform_(self.lora_A_k.weight, a=math.sqrt(5)) |
| nn.init.kaiming_uniform_(self.lora_A_v.weight, a=math.sqrt(5)) |
| nn.init.zeros_(self.lora_B_k.weight) |
| nn.init.zeros_(self.lora_B_v.weight) |
|
|
| self.apply_lora = True |
|
|
| def merge_weight(self): |
| |
| q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
| k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight |
| v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight |
| self.qkv.weight.data = torch.cat([q_weight, k_weight, v_weight], dim=0) |
| self.apply_lora = False |
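
    # Sanity-check sketch (illustrative, run manually): after merge_weight(),
    # the plain qkv projection should reproduce the LoRA-adapted forward pass:
    #
    #   attn = MultiHeadAttention_LoRA(dim=768)
    #   attn.init_param()
    #   x = torch.randn(2, 5, 768)
    #   y_lora = attn(x)
    #   attn.merge_weight()
    #   y_merged = attn(x)            # apply_lora is now False
    #   assert torch.allclose(y_lora, y_merged, atol=1e-5)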
|
|
| def reset_input_matrix(self): |
| self.cur_matrix.zero_() |
| self.n_cur_matrix = 0 |
|
|
| def forward(self, x, attn_mask=None, register_hook=False, prompt=None, get_input_matrix = False): |
| |
        if get_input_matrix:
            # Running average of the token Gram matrix x^T x over the batch,
            # accumulated on CPU for later subspace extraction.
            self.cur_matrix = (
                self.cur_matrix * self.n_cur_matrix
                + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu()
            ) / (self.n_cur_matrix + x.shape[0] * x.shape[1])
            self.n_cur_matrix += x.shape[0] * x.shape[1]
|
|
| B, N, C = x.shape |
|
|
| q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
|
|
| if self.apply_lora: |
| k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight |
| v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight |
| |
| qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| return x |
|
|
| class MultiHeadAttention_SDLoRA(MultiHeadAttention): |
| def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): |
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) |
|
|
| self.lora_rank = lora_rank |
| self.lora_bias = lora_bias |
| |
| self.lora_A_q_list = nn.ModuleList([]) |
| self.lora_B_q_list = nn.ModuleList([]) |
| self.lora_A_v_list = nn.ModuleList([]) |
| self.lora_B_v_list = nn.ModuleList([]) |
|
|
| self.assimilated_mag_lora_q = [] |
| self.assimilated_mag_lora_v = [] |
|
|
| def init_param(self): |
|
|
| self.lora_A_q_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias)) |
| self.lora_B_q_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias)) |
| self.lora_A_v_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias)) |
| self.lora_B_v_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias)) |
|
|
| nn.init.kaiming_uniform_(self.lora_A_q_list[-1].weight, a=math.sqrt(5)) |
| nn.init.kaiming_uniform_(self.lora_A_v_list[-1].weight, a=math.sqrt(5)) |
| nn.init.zeros_(self.lora_B_q_list[-1].weight) |
| nn.init.zeros_(self.lora_B_v_list[-1].weight) |
|
|
| self.assimilated_mag_lora_q.append( |
| torch.Tensor([0.0]).to(self.qkv.weight.device) |
| ) |
| self.assimilated_mag_lora_v.append( |
| torch.Tensor([0.0]).to(self.qkv.weight.device) |
| ) |
|
|
        # `mag_lora` (per-task magnitude parameters) is expected to be attached
        # to this module externally before init_param() runs.
        assert len(self.lora_A_q_list) == len(self.mag_lora)
        assert len(self.mag_lora) == len(self.assimilated_mag_lora_q)
|
|
| def forward(self, x, attn_mask=None, register_hook=False, prompt=None): |
| |
| B, N, C = x.shape |
|
|
| qq = self.mag_lora[-1] * self.lora_B_q_list[-1](self.lora_A_q_list[-1](x)) |
| vv = self.mag_lora[-1] * self.lora_B_v_list[-1](self.lora_A_v_list[-1](x)) |
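
        # Adapters from earlier tasks are re-applied with their update
        # direction normalized by ||B|| * ||A|| and scaled by the learned
        # (plus assimilated) per-task magnitudes.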
|
|
| for i in range(len(self.lora_A_q_list) - 1): |
|
|
| norm_B = torch.norm(self.lora_B_q_list[i].weight) |
| norm_A = torch.norm(self.lora_A_q_list[i].weight) |
| |
| if norm_B != 0 and norm_A != 0: |
| qq += (self.mag_lora[i] + self.assimilated_mag_lora_q[i]) * self.lora_B_q_list[i](self.lora_A_q_list[i](x)) / (norm_B * norm_A) |
|
|
| norm_B = torch.norm(self.lora_B_v_list[i].weight) |
| norm_A = torch.norm(self.lora_A_v_list[i].weight) |
|
|
| if norm_B != 0 and norm_A != 0: |
| vv += (self.mag_lora[i] + self.assimilated_mag_lora_v[i]) * self.lora_B_v_list[i](self.lora_A_v_list[i](x)) / (norm_B * norm_A) |
|
|
| qkv = self.qkv(x) |
| qkv[:, :, : self.dim] += qq |
| qkv[:, :, -self.dim :] += vv |
|
|
| qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| return x |
|
|
| class MultiHeadAttention_LoRA_Sub(MultiHeadAttention): |
|
|
    '''
    Attention module with LoRA applied to the k and v projections; previously
    learned LoRA deltas are accumulated into `prev_k_weight` / `prev_v_weight`.
    '''
|
|
| def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): |
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) |
|
|
| self.lora_rank = lora_rank |
| |
| self.lora_A_k = nn.Linear(self.dim, self.lora_rank, bias=lora_bias) |
| self.lora_B_k = nn.Linear(self.lora_rank, self.dim, bias=lora_bias) |
| self.lora_A_v = nn.Linear(self.dim, self.lora_rank, bias=lora_bias) |
| self.lora_B_v = nn.Linear(self.lora_rank, self.dim, bias=lora_bias) |
| self.apply_lora = False |
|
|
        self.cur_matrix = torch.zeros(self.dim, self.dim)
| self.n_cur_matrix = 0 |
|
|
| self.register_buffer("prev_k_weight", torch.zeros(self.dim, self.dim)) |
| self.register_buffer("prev_v_weight", torch.zeros(self.dim, self.dim)) |
|
|
| def init_param(self): |
|
|
| nn.init.kaiming_uniform_(self.lora_A_k.weight, a=math.sqrt(5)) |
| nn.init.kaiming_uniform_(self.lora_A_v.weight, a=math.sqrt(5)) |
| nn.init.zeros_(self.lora_B_k.weight) |
| nn.init.zeros_(self.lora_B_v.weight) |
|
|
| self.apply_lora = True |
|
|
| def save_weight(self): |
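        # Fold the just-trained LoRA delta into the running "previous tasks"
        # buffers; later tasks then train relative to the accumulated shift.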
|
|
| self.prev_k_weight += self.lora_B_k.weight @ self.lora_A_k.weight |
| self.prev_v_weight += self.lora_B_v.weight @ self.lora_A_v.weight |
| self.apply_lora = False |
|
|
| def reset_input_matrix(self): |
| self.cur_matrix.zero_() |
| self.n_cur_matrix = 0 |
|
|
| def forward(self, x, attn_mask=None, register_hook=False, prompt=None, get_input_matrix = False): |
| |
| B, N, C = x.shape |
|
|
| q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
|
|
| if get_input_matrix: |
| |
            self.cur_matrix = (
                self.cur_matrix * self.n_cur_matrix
                + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu()
            ) / (self.n_cur_matrix + x.shape[0] * x.shape[1])
            self.n_cur_matrix += x.shape[0] * x.shape[1]
|
|
| k_weight = k_weight - self.prev_k_weight |
| v_weight = v_weight - self.prev_v_weight |
|
|
| elif self.apply_lora: |
| |
| k_weight = k_weight + self.prev_k_weight + self.lora_B_k.weight @ self.lora_A_k.weight |
| v_weight = v_weight + self.prev_v_weight + self.lora_B_v.weight @ self.lora_A_v.weight |
| else: |
| |
| k_weight = k_weight + self.prev_k_weight |
| v_weight = v_weight + self.prev_v_weight |
|
|
| qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| return x |
|
|
| class MultiHeadAttention_CL_LoRA(MultiHeadAttention_LoRA): |
|
|
    '''
    Attention module with LoRA applied to the q and v projections.
    '''
|
|
| def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): |
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) |
| |
| del self.lora_A_k |
| del self.lora_B_k |
| self.lora_A_q = nn.Linear(self.dim, self.lora_rank, bias=lora_bias) |
| self.lora_B_q = nn.Linear(self.lora_rank, self.dim, bias=lora_bias) |
|
|
| def init_param(self): |
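        # Initialize A_q / A_v with orthonormal rows (QR of a random matrix)
        # and B = 0: the update starts at zero but spans a well-conditioned
        # rank-`lora_rank` subspace.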
| |
| q1, _ = torch.linalg.qr(torch.rand(self.dim, self.lora_rank)) |
| q2, _ = torch.linalg.qr(torch.rand(self.dim, self.lora_rank)) |
| with torch.no_grad(): |
| self.lora_A_q.weight.copy_(q1.T) |
| self.lora_A_v.weight.copy_(q2.T) |
|
|
| scaling_factor = 1. |
| self.lora_A_q.weight.data *= scaling_factor |
| self.lora_A_v.weight.data *= scaling_factor |
|
|
| nn.init.zeros_(self.lora_B_q.weight) |
| nn.init.zeros_(self.lora_B_v.weight) |
|
|
| def forward( |
| self, |
| x, |
| adapt=None, |
| prompt=None, |
| rank_prompt=None, |
| block_weight=None, |
| attn_mask=None, |
| register_hook=False): |
| |
| |
| |
| |
|
|
| B, N, C = x.shape |
|
|
| q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
|
|
| qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data) |
| |
        if adapt is not None:
            if block_weight is None:
                block_weight = torch.ones(3).to(x.device)
            qq = block_weight[0] * adapt[0](x)
            vv = block_weight[2] * adapt[2](x)
|
|
| qkv[:, :, : self.dim] += qq |
| qkv[:, :, -self.dim :] += vv |
|
|
| qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
|
|
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| return x |
|
|
| |
| class MultiHeadAttention_MaskedLoRA(MultiHeadAttention_LoRA): |
|
|
| |
| |
| def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): |
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) |
|
|
| |
| self.identity_matrix = torch.eye(self.qkv.weight.shape[1]) |
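
        # Two (space, scale, mask) slots per task; 10 is the assumed maximum
        # number of tasks.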
| |
| self.space = [[0, 0] for _ in range(10)] |
| self.scale_param = nn.ModuleList([nn.ParameterList([nn.Parameter(self.identity_matrix) for _ in range(2)]) for _ in range(10)]) |
| self.scaling_mask = [[False, False] for _ in range(10)] |
|
|
| def enable_scale(self, task_id, space): |
| if len(space) == 2: |
| self.space[task_id][0] = space[0] |
| self.space[task_id][1] = space[1] |
| self.scaling_mask[task_id][0] = True |
| self.scaling_mask[task_id][1] = True |
| elif len(space) == 1: |
| self.space[task_id][0] = space[0] |
| self.scaling_mask[task_id][0] = True |
|
|
        for scale_param_list in self.scale_param:
            for scale_param in scale_param_list:
                # Move the parameter's storage in place; rebinding the loop
                # variable would not affect the registered parameter.
                scale_param.data = scale_param.data.to(self.qkv.weight.device)
|
|
| def forward(self, x, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix = False): |
|
|
| if get_input_matrix: |
            self.cur_matrix = (
                self.cur_matrix * self.n_cur_matrix
                + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu()
            ) / (self.n_cur_matrix + x.shape[0] * x.shape[1])
            self.n_cur_matrix += x.shape[0] * x.shape[1]
| |
| B, N, C = x.shape |
|
|
| q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
|
|
| if self.apply_lora: |
| k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight |
| v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight |
| |
| for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]): |
|
|
| if not mask: |
| break |
| |
| scale_size = space.shape[1] |
| cropped_scale = scale[:scale_size, :scale_size] |
|
|
| cropped_scale = cropped_scale @ cropped_scale.T |
|
|
| cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device) |
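
            # Rescale the weights inside the saved subspace Q = `space`:
            # W <- W + W Q (S S^T - I) Q^T, i.e. the component of W lying in
            # span(Q) is replaced by its learned rescaling.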
|
|
| k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
| v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
|
|
| qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
| return x |
|
|
| |
| class MultiHeadAttention_MaskedLoRA1(MultiHeadAttention): |
| def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): |
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) |
|
|
| self.cur_task = -1 |
| self.lora_rank = lora_rank |
| |
        self.cur_matrix = torch.zeros(self.dim, self.dim)
| self.n_cur_matrix = 0 |
|
|
| self.lora_bias = lora_bias |
|
|
| self.lora_A_k_list = nn.ModuleList([]) |
| self.lora_B_k_list = nn.ModuleList([]) |
| self.lora_A_v_list = nn.ModuleList([]) |
| self.lora_B_v_list = nn.ModuleList([]) |
|
|
| self.space_k = [0 for _ in range(10)] |
| self.space_v = [0 for _ in range(10)] |
| self.identity_matrix = torch.eye(self.qkv.weight.shape[1]) |
| self.scale_param = nn.ParameterList([]) |
|
|
| def init_param(self): |
|
|
| self.lora_A_k_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias)) |
| self.lora_B_k_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias)) |
| self.lora_A_v_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias)) |
| self.lora_B_v_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias)) |
| self.scale_param.append(nn.Parameter(self.identity_matrix)) |
|
|
| nn.init.kaiming_uniform_(self.lora_A_k_list[-1].weight, a=math.sqrt(5)) |
| nn.init.kaiming_uniform_(self.lora_A_v_list[-1].weight, a=math.sqrt(5)) |
| nn.init.zeros_(self.lora_B_k_list[-1].weight) |
| nn.init.zeros_(self.lora_B_v_list[-1].weight) |
|
|
| self.cur_task += 1 |
|
|
| def reset_input_matrix(self): |
| self.cur_matrixs = [] |
|
|
| def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False): |
| |
| if get_input_matrix: |
| assert x.shape[0] < 512 |
| self.cur_matrixs.append(x.detach()) |
|
|
| if x.shape[0] > 128: |
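            # Re-base the shared LoRA "A" on the dominant input directions:
            # SVD the (optionally residual-projected) activation Gram matrix,
            # take the top-`lora_rank` left singular vectors as A_new, and
            # re-solve B via a pseudoinverse so B_new @ A_new ~= B_old @ A_old.
            # (`self.feature_mat` is expected to be attached externally once
            # cur_task > 0.)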
| |
| activation = torch.bmm(x.permute(0, 2, 1), x).sum(dim=0) / x.shape[0] |
|
|
| |
| activation = self.lora_A_k_list[-1].weight.data.T @ self.lora_A_k_list[-1].weight.data @ activation |
|
|
| if self.cur_task > 0: |
| activation = activation - self.feature_mat @ activation |
|
|
| U, _, _ = torch.linalg.svd(activation, full_matrices = False) |
| A_new = U[:,:self.lora_rank].T / math.sqrt(3) |
| A_old = self.lora_A_k_list[-1].weight.data |
| Bk_old = self.lora_B_k_list[-1].weight.data |
| Bv_old = self.lora_B_v_list[-1].weight.data |
|
|
| tmp = A_old @ torch.pinverse(A_new) |
| Bk_new = Bk_old @ tmp |
| Bv_new = Bv_old @ tmp |
|
|
| ''' |
| # Compute matmul results |
| Bk_old_A_old = Bk_old @ A_old |
| Bk_new_A_new = Bk_new @ A_new |
| Bv_old_A_old = Bv_old @ A_old |
| Bv_new_A_new = Bv_new @ A_new |
| |
| # Compute the Frobenius norm of the difference between old and new matmul results |
| frobenius_norm_Bk = torch.norm(Bk_old_A_old - Bk_new_A_new, p='fro') |
| frobenius_norm_Bv = torch.norm(Bv_old_A_old - Bv_new_A_new, p='fro') |
| |
| # Printing the results |
| print(f"Frobenius norm difference between Bk_old @ A_old and Bk_new @ A_new: {frobenius_norm_Bk.item()}") |
| print(f"Frobenius norm difference between Bv_old @ A_old and Bv_new @ A_new: {frobenius_norm_Bv.item()}") |
| ''' |
|
|
| self.lora_A_k_list[-1].weight.data.copy_(A_new) |
| self.lora_A_v_list[-1].weight.data.copy_(A_new) |
| self.lora_B_k_list[-1].weight.data.copy_(Bk_new) |
| self.lora_B_v_list[-1].weight.data.copy_(Bv_new) |
|
|
| B, N, C = x.shape |
| q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
|
|
| for ii in range(self.cur_task): |
| k_weight = k_weight + self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight |
| v_weight = v_weight + self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight |
|
|
| k_weight = k_weight + self.lora_B_k_list[-1].weight @ self.lora_A_k_list[-1].weight |
| v_weight = v_weight + self.lora_B_v_list[-1].weight @ self.lora_A_v_list[-1].weight |
|
|
| ''' |
| for ii in range(self.cur_task): |
| if not isinstance(self.space_k[ii], int): |
| |
| space_k = self.space_k[ii] |
| space_v = self.space_v[ii] |
| scale_k = self.scale_param[ii] |
| |
| # Q Scaling |
| scalee = scale_k[:space_k.shape[0], :space_k.shape[0]] |
| |
| # QQ^T Scaling |
| scalee = scale_k[:space_k.shape[0], :space_k.shape[0]] @ scale_k[:space_k.shape[0], :space_k.shape[0]].T |
| |
| # QQ^T Diagonal Scaling12 |
| #scalee = torch.diag(torch.diag(scale_k[:space_k.shape[0], :space_k.shape[0]] @ scale_k[:space_k.shape[0], :space_k.shape[0]].T)) |
| |
| # Q Diagonal Scaling |
| #scalee = torch.diag(torch.diag(scale_k[:space_k.shape[0], :space_k.shape[0]])) |
| |
| #scalee = scale_k[0, 0] |
| scalee = self.mag_lora[ii] |
| |
| use_scale = False |
| if use_scale: |
| |
| norm_B = torch.norm(self.lora_B_k_list[ii].weight) |
| norm_A = torch.norm(self.lora_A_k_list[ii].weight) |
| |
| k_weight = k_weight - self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight @ space_k.T @ space_k |
| k_weight = k_weight + scalee * (self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight @ space_k.T @ space_k) / (norm_B * norm_A) |
| |
| norm_B = torch.norm(self.lora_B_v_list[ii].weight) |
| norm_A = torch.norm(self.lora_A_v_list[ii].weight) |
| |
| v_weight = v_weight - self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight @ space_v.T @ space_v |
| v_weight = v_weight + scalee * (self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight @ space_v.T @ space_v) / (norm_B * norm_A) |
| ''' |
|
|
| qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| return x, x, probs |
|
|
| |
| class MultiHeadAttention_MultiMaskedLoRA(MultiHeadAttention_MaskedLoRA): |
| def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): |
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) |
|
|
| self.activated_expert = 0 |
        # Placeholders until save_space() stores the real per-task bases.
        self.saved_space = [[torch.tensor(1), torch.tensor(1)] for _ in range(10)]
|
|
| self.hit = 0 |
| self.total = 0 |
        self.projected_cur_matrix = torch.zeros(self.dim, self.dim)
| self.n_projected_cur_matrix = 0 |
|
|
| def reset_input_matrix(self): |
| super().reset_input_matrix() |
| self.projected_cur_matrix.zero_() |
| self.n_projected_cur_matrix = 0 |
|
|
| def enable_scale(self, task_id, space): |
| |
| if len(space) == 2: |
| self.space[task_id][0] = space[0] |
| self.space[task_id][1] = space[1] |
| self.scaling_mask[task_id][0] = True |
| self.scaling_mask[task_id][1] = True |
| elif len(space) == 1: |
| self.space[task_id][0] = space[0] |
| self.scaling_mask[task_id][0] = True |
|
|
        for scale_param_list in self.scale_param:
            for scale_param in scale_param_list:
                scale_param.data = scale_param.data.to(self.qkv.weight.device)
|
|
| def save_space(self, task_id, space): |
| self.activated_expert = task_id |
| self.saved_space[task_id][0] = space |
|
|
| def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False): |
| |
| B, N, C = x.shape |
|
|
| if get_input_matrix: |
| assert expert_id == 0 |
            self.cur_matrix = (
                self.cur_matrix * self.n_cur_matrix
                + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu()
            ) / (self.n_cur_matrix + B * N)
            self.n_cur_matrix += B * N
|
|
| |
| if not self.training and not get_input_matrix: |
| with torch.no_grad(): |
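
                # Expert selection at inference: project the batch Gram matrix
                # onto each saved task subspace and route to the expert whose
                # subspace captures the most input energy (largest norm).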
|
|
| cur_cur_matrix = torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0) / (B * N) |
| saved = torch.stack([self.saved_space[idd][0] for idd in range(self.activated_expert + 1)]).to(x.device) |
| |
|
|
| proj_mat = saved.transpose(1, 2) |
| proj_mat = torch.einsum('ijk,kl->ijl', proj_mat, cur_cur_matrix) |
| |
                proj_norm = np.linalg.norm(proj_mat.cpu().numpy(), axis=(1, 2))
                proj_norm = softmax(proj_norm)
                probs.append(proj_norm)
                selected_expert_id = np.argmax(proj_norm, axis=0)
| |
| expert_id = selected_expert_id |
|
|
| q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
|
|
| if self.apply_lora: |
| k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight |
| v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight |
| |
| qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| |
|
|
| for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]): |
|
|
| if not mask: |
| break |
|
|
| scale_size = space.shape[1] |
| cropped_scale = scale[:scale_size, :scale_size] |
|
|
| cropped_scale = cropped_scale @ cropped_scale.T |
|
|
| cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device) |
|
|
| k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
| v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
| |
| qkv = F.linear(x_proj, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x_proj = self.proj(x_proj) |
| x_proj = self.proj_drop(x_proj) |
|
|
| return x, x_proj, probs |
|
|
| def forward1(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False): |
| |
| B, N, C = x.shape |
|
|
| if get_input_matrix: |
| assert expert_id == 0 |
            self.cur_matrix = (
                self.cur_matrix * self.n_cur_matrix
                + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu()
            ) / (self.n_cur_matrix + B * N)
            self.n_cur_matrix += B * N
| |
| |
| if not self.training and not get_input_matrix: |
| with torch.no_grad(): |
|
|
| cur_cur_matrix = torch.bmm(x.detach().permute(0, 2, 1), x.detach()) / N |
| cur_cur_matrix = cur_cur_matrix.permute(1, 2, 0) |
| saved = torch.stack([self.saved_space[idd][0] for idd in range(self.activated_expert + 1)]).to(x.device) |
| proj_mat = saved.transpose(1, 2) |
|
|
| proj_mat = torch.einsum('ijk,klm->ijlm', proj_mat, cur_cur_matrix) |
|
|
                proj_norm = np.linalg.norm(proj_mat.cpu().numpy(), axis=(1, 2))
                proj_norm = softmax(proj_norm, axis=0)
| |
| probs.append(proj_norm) |
|
|
| selected_expert_id = np.argmax(proj_norm, axis = 0) |
| selected_expert_id = torch.tensor(selected_expert_id).to(x.device) |
|
|
|
|
| q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
|
|
| if self.apply_lora: |
| k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight |
| v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight |
| |
| qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| |
| if not self.training and not get_input_matrix: |
| inputs = [x_proj.clone() for _ in range(self.activated_expert + 1)] |
| k_weights = [k_weight.clone() for _ in range(self.activated_expert + 1)] |
| v_weights = [v_weight.clone() for _ in range(self.activated_expert + 1)] |
| qkv_outputs = [] |
|
|
| for ex in range(self.activated_expert + 1): |
|
|
| for mask, scale, space in zip(self.scaling_mask[ex], self.scale_param[ex], self.space[ex]): |
|
|
| if not mask: |
| break |
|
|
| scale_size = space.shape[1] |
| cropped_scale = scale[:scale_size, :scale_size] |
|
|
| cropped_scale = cropped_scale @ cropped_scale.T |
|
|
| cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(x.device) |
|
|
| k_weights[ex] = k_weights[ex] + k_weights[ex] @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
| v_weights[ex] = v_weights[ex] + v_weights[ex] @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
|
|
| cur_selected = selected_expert_id.unsqueeze(-1).unsqueeze(-1) |
|
|
| mask = (cur_selected == ex) |
| inputs[ex] *= mask |
|
|
| inputs[ex] = inputs[ex].to(x.device) |
| q_weight = q_weight.to(x.device) |
| k_weights[ex] = k_weights[ex].to(x.device) |
| v_weights[ex] = v_weights[ex].to(x.device) |
|
|
| qkv = F.linear(inputs[ex], torch.cat([q_weight, k_weights[ex], v_weights[ex]], dim=0)) |
| qkv_outputs.append(qkv) |
|
|
| stacked = torch.stack(qkv_outputs) |
| qkv = torch.sum(stacked, dim=0) |
| qkv = qkv + self.qkv.bias |
| qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
|
|
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x_proj = self.proj(x_proj) |
| x_proj = self.proj_drop(x_proj) |
|
|
| else: |
|
|
| for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]): |
|
|
| if not mask: |
| break |
|
|
| scale_size = space.shape[1] |
| cropped_scale = scale[:scale_size, :scale_size] |
|
|
| cropped_scale = cropped_scale @ cropped_scale.T |
|
|
| cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device) |
|
|
| k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
| v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T |
| |
| qkv = F.linear(x_proj, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x_proj = self.proj(x_proj) |
| x_proj = self.proj_drop(x_proj) |
|
|
| return x, x_proj, probs |
|
|
| |
| class MultiHeadAttention_MultiMaskedLoRA3(MultiHeadAttention_MaskedLoRA): |
| def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): |
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) |
|
|
| self.cur_task = -1 |
|
|
| self.lora_A_k_list = nn.ModuleList([nn.Linear(self.dim, self.lora_rank, bias=lora_bias) for _ in range(10)]) |
| self.lora_B_k_list = nn.ModuleList([nn.Linear(self.lora_rank, self.dim, bias=lora_bias) for _ in range(10)]) |
| self.lora_A_v_list = nn.ModuleList([nn.Linear(self.dim, self.lora_rank, bias=lora_bias) for _ in range(10)]) |
| self.lora_B_v_list = nn.ModuleList([nn.Linear(self.lora_rank, self.dim, bias=lora_bias) for _ in range(10)]) |
|
|
| self.space_k = [0 for _ in range(10)] |
| self.space_v = [0 for _ in range(10)] |
| self.scale_param = nn.ParameterList([nn.Parameter(self.identity_matrix) for _ in range(10)]) |
|
|
| def init_param(self): |
|
|
| self.cur_task += 1 |
|
|
| i = self.cur_task |
|
|
| nn.init.kaiming_uniform_(self.lora_A_k_list[i].weight, a=math.sqrt(5)) |
| nn.init.kaiming_uniform_(self.lora_A_v_list[i].weight, a=math.sqrt(5)) |
| nn.init.zeros_(self.lora_B_k_list[i].weight) |
| nn.init.zeros_(self.lora_B_v_list[i].weight) |
|
|
    def merge_weight(self):

        # Merging is intentionally disabled for this variant; the legacy code
        # below is unreachable and kept only for reference.
        print('Not MERGED')
        return 0
|
|
| q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
| k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight |
| v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight |
|
|
| self.apply_lora = False |
|
|
| for exp_id in range(10): |
| for ii, mask, scale_k, scale_v, space_k, space_v in zip([0, 1], self.scaling_mask[exp_id], self.scale_param_k[exp_id], self.scale_param_v[exp_id], self.space_k[exp_id], self.space_v[exp_id]): |
|
|
| if isinstance(space_k, int): |
| break |
|
|
| k_weight = k_weight - k_weight @ space_k.T @ space_k + k_weight @ space_k.T @ scale_k[:space_k.shape[0], :space_k.shape[0]] @ space_k |
| v_weight = v_weight - v_weight @ space_v.T @ space_v + v_weight @ space_v.T @ scale_k[:space_v.shape[0], :space_v.shape[0]] @ space_v |
|
|
| self.space_k[exp_id][ii] = 0 |
|
|
| self.qkv.weight.data = torch.cat([q_weight, k_weight, v_weight], dim=0) |
|
|
    def save_dir(self):

        # Intentionally disabled; the code below is unreachable and kept only
        # for reference.
        return 0
|
|
| self.cur_task += 1 |
|
|
| ''' |
| |
| norm = torch.linalg.matrix_norm(self.lora_B_k.weight @ self.lora_A_k.weight) |
| |
| self.lora_A_k.weight.data = self.lora_A_k.weight.data / norm |
| self.lora_B_k.weight.data = self.lora_B_k.weight.data / norm |
| |
| self.space_k[self.cur_task][0] = self.lora_A_k.weight.data.clone() / norm |
| |
| norm = torch.linalg.matrix_norm(self.lora_B_v.weight @ self.lora_A_v.weight) |
| |
| self.lora_A_v.weight.data = self.lora_A_v.weight.data / norm |
| self.lora_B_v.weight.data = self.lora_B_v.weight.data / norm |
| |
        self.space_v[self.cur_task][0] = self.lora_A_v.weight.data.clone() / norm
| ''' |
|
|
| _, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
|
|
| U, _, _ = np.linalg.svd(k_weight.data, full_matrices = False) |
| U, _, _ = np.linalg.svd(U[:, :10], full_matrices = False) |
| orto_proj = U[:, -50:] |
|
|
| self.space_k[self.cur_task][0] = torch.Tensor(orto_proj.T).to(self.qkv.weight.device) |
|
|
| U, _, _ = np.linalg.svd(v_weight.data, full_matrices = False) |
| U, _, _ = np.linalg.svd(U[:, :10], full_matrices = False) |
| orto_proj = U[:, -50:] |
|
|
| self.space_v[self.cur_task][0] = torch.Tensor(orto_proj.T).to(self.qkv.weight.device) |
|
|
| self.scaling_mask[self.cur_task][0] = True |
|
|
| def enable_scale(self, task_id, space): |
| |
| if len(space) == 2: |
| self.space[task_id][0] = space[0] |
| self.space[task_id][1] = space[1] |
| self.scaling_mask[task_id][0] = True |
| self.scaling_mask[task_id][1] = True |
| elif len(space) == 1: |
| self.space[task_id][0] = space[0] |
| self.scaling_mask[task_id][0] = True |
|
|
        # `scale_param` is a flat ParameterList here; move each parameter's
        # storage in place (rebinding a loop variable would have no effect).
        for scale_param in self.scale_param:
            scale_param.data = scale_param.data.to(self.qkv.weight.device)
|
|
| def save_space(self, task_id, space): |
| self.activated_expert = task_id |
| self.saved_space[task_id].append(space) |
|
|
| def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False): |
| |
| B, N, C = x.shape |
|
|
| if get_input_matrix: |
            self.cur_matrix = (
                self.cur_matrix * self.n_cur_matrix
                + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu()
            ) / (self.n_cur_matrix + B * N)
            self.n_cur_matrix += B * N
| |
| q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) |
|
|
| |
        # Diagnostic loop (disabled by the immediate break): prints Frobenius
        # norms of the scaled subspace components for inspection.
        for exp_id in range(10):

            break
|
|
| for mask, scale, space_k, space_v in zip(self.scaling_mask[exp_id], self.scale_param[exp_id], self.space_k[exp_id], self.space_v[exp_id]): |
|
|
| if isinstance(space_k, int): |
| break |
|
|
| cropped_scale = scale[:space_k.shape[0], :space_k.shape[0]] |
| print( |
| round(torch.linalg.norm(k_weight @ space_k.T @ space_k, ord='fro').item(), 2), |
| round(torch.linalg.norm(k_weight @ space_k.T @ cropped_scale @ space_k, ord='fro').item(), 2), |
| round(torch.linalg.norm(self.lora_B_k.weight @ self.lora_A_k.weight @ space_k.T @ space_k, ord='fro').item(), 2), |
| round(torch.linalg.norm(self.lora_B_k.weight @ self.lora_A_k.weight @ space_k.T @ cropped_scale @ space_k, ord='fro').item(), 2), |
| ) |
|
|
| for ii in range(self.cur_task + 1): |
| k_weight = k_weight + self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight |
| v_weight = v_weight + self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight |
|
|
| if not isinstance(self.space_k[ii], int): |
|
|
| space_k = self.space_k[ii] |
| space_v = self.space_v[ii] |
| scale_k = self.scale_param[ii] |
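
                # Several scaling variants were explored below; only the last
                # assignment to `scalee` takes effect (the diagonal of the raw
                # scale block).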
|
|
| |
| scalee = scale_k[:space_k.shape[0], :space_k.shape[0]] |
|
|
| |
| scalee = scale_k[:space_k.shape[0], :space_k.shape[0]] @ scale_k[:space_k.shape[0], :space_k.shape[0]].T |
|
|
| |
| scalee = torch.diag(torch.diag(scale_k[:space_k.shape[0], :space_k.shape[0]] @ scale_k[:space_k.shape[0], :space_k.shape[0]].T)) |
|
|
| |
| scalee = torch.diag(torch.diag(scale_k[:space_k.shape[0], :space_k.shape[0]])) |
|
|
| |
| |
|
|
| use_scale = True |
| if use_scale: |
| |
| |
| dir_k = space_k |
| k_weight = k_weight - k_weight @ space_k.T @ space_k + k_weight @ dir_k.T @ scalee @ dir_k |
|
|
| |
| dir_v = space_v |
|
|
| v_weight = v_weight - v_weight @ space_v.T @ space_v + v_weight @ dir_v.T @ scalee @ dir_v |
| else: |
| pass |
| |
|
|
| |
| |
| |
|
|
| qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| if attn_mask is not None: |
| attn += attn_mask.unsqueeze(0) |
|
|
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
| |
| if register_hook: |
| self.save_attention_map(attn) |
| attn.register_hook(self.save_attn_gradients) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| return x, x, probs |
|
|
|
|
| |
| class Mlp(nn.Module): |
| """ MLP as used in Vision Transformer, MLP-Mixer and related networks |
| """ |
| def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): |
| super().__init__() |
| out_features = out_features or in_features |
| hidden_features = hidden_features or in_features |
| self.fc1 = nn.Linear(in_features, hidden_features) |
| self.act = act_layer() |
| self.fc2 = nn.Linear(hidden_features, out_features) |
| self.drop = nn.Dropout(drop) |
|
|
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.act(x) |
| x = self.drop(x) |
| x = self.fc2(x) |
| x = self.drop(x) |
| return x |
|
|
| |
| class ResidualAttentionBlock(nn.Module): |
| def __init__(self, |
| d_model: int, |
| n_head: int, |
| mlp_ratio: float = 4., |
| qkv_bias: bool = True, |
| qk_scale: float = None, |
| attn_drop: float = 0., |
| proj_drop: float = 0., |
| drop_path: float = 0., |
| attn_layer = MultiHeadAttention, |
| act_layer = nn.GELU, |
| norm_layer = nn.LayerNorm, |
| norm_layer_eps = 1e-5, |
| attn_mask: torch.Tensor = None, |
| text_or_image=None, |
| |
| lora_rank: int = 0, |
| lora_bias: bool = False |
| ): |
| super().__init__() |
|
|
        if attn_layer == MultiHeadAttention:
            self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop)
        elif attn_layer in (
            MultiHeadAttention_LoRA,
            MultiHeadAttention_SDLoRA,
            MultiHeadAttention_LoRA_Sub,
            MultiHeadAttention_MaskedLoRA,
            MultiHeadAttention_MultiMaskedLoRA,
            MultiHeadAttention_CL_LoRA,
        ):
            self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)
        else:
            raise NotImplementedError(f'{attn_layer} is not implemented')
| |
| self.ln_1 = norm_layer(d_model, eps=norm_layer_eps) |
| self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer) |
| self.ln_2 = norm_layer(d_model, eps=norm_layer_eps) |
| self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
| self.attn_mask = attn_mask |
| self.text_or_image = text_or_image |
| |
| def attention(self, x: torch.Tensor, **kwargs): |
| self.attn_mask = self.attn_mask.to(x) if self.attn_mask is not None else None |
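
        # Activations flow through the transformer as [seq, batch, dim]
        # (CLIP convention); MultiHeadAttention expects [batch, seq, dim].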
| |
| x = x.permute(1, 0, 2) |
| attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) |
| attn = attn.permute(1, 0, 2) |
|
|
| return attn |
|
|
| def forward(self, x: torch.Tensor, **kwargs): |
|
|
| x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs)) |
| x = x + self.drop_path(self.mlp(self.ln_2(x))) |
|
|
| return x |
|
|
| class ResidualAttentionBlock_MLP(ResidualAttentionBlock): |
| def __init__(self, |
| d_model: int, |
| n_head: int, |
| mlp_ratio: float = 4., |
| qkv_bias: bool = True, |
| qk_scale: float = None, |
| attn_drop: float = 0., |
| proj_drop: float = 0., |
| drop_path: float = 0., |
| attn_layer = MultiHeadAttention, |
| act_layer = nn.GELU, |
| norm_layer = nn.LayerNorm, |
| attn_mask: torch.Tensor = None, |
| text_or_image=None, |
| |
| lora_rank: int = 0, |
| lora_bias: bool = False, |
| ): |
        super().__init__(
            d_model,
            n_head,
            mlp_ratio,
            qkv_bias,
            qk_scale,
            attn_drop,
            proj_drop,
            drop_path,
            attn_layer,
            act_layer,
            norm_layer,
            # Passed by keyword: the parent signature has `norm_layer_eps`
            # before these, so positional passing would misalign them.
            attn_mask=attn_mask,
            text_or_image=text_or_image)
|
|
| self.ffn_num = 64 |
| self.adaptmlp = Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num, |
| init_option='lora', adapter_scalar=0.1, adapter_layernorm_option='none') |
|
|
| self.lora_feature = None |
| |
| def attention(self, x: torch.Tensor, **kwargs): |
| self.attn_mask = self.attn_mask.to(x) if self.attn_mask is not None else None |
| |
| x = x.permute(1, 0, 2) |
| attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) |
| attn = attn.permute(1, 0, 2) |
|
|
| return attn |
|
|
| def forward(self, x: torch.Tensor, compute_lora_feat = False, **kwargs): |
| |
| x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs)) |
|
|
| x_re = x.permute(1, 0, 2) |
| adapt_x = self.adaptmlp(x_re, add_residual=False) |
| adapt_x = adapt_x.permute(1, 0, 2) |
|
|
| x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x) |
|
|
| if compute_lora_feat: |
| self.lora_feature = adapt_x.detach().cpu() |
|
|
| return x |
|
|
| class ResidualAttentionBlock_MaskedMLP(ResidualAttentionBlock): |
| def __init__(self, |
| d_model: int, |
| n_head: int, |
| mlp_ratio: float = 4., |
| qkv_bias: bool = True, |
| qk_scale: float = None, |
| attn_drop: float = 0., |
| proj_drop: float = 0., |
| drop_path: float = 0., |
| attn_layer = MultiHeadAttention, |
| act_layer = nn.GELU, |
| norm_layer = nn.LayerNorm, |
| attn_mask: torch.Tensor = None, |
| text_or_image=None, |
| |
| lora_rank: int = 0, |
| lora_bias: bool = False, |
| ): |
        super().__init__(
            d_model,
            n_head,
            mlp_ratio,
            qkv_bias,
            qk_scale,
            attn_drop,
            proj_drop,
            drop_path,
            attn_layer,
            act_layer,
            norm_layer,
            # Passed by keyword: the parent signature has `norm_layer_eps`
            # before these, so positional passing would misalign them.
            attn_mask=attn_mask,
            text_or_image=text_or_image)
|
|
| self.ffn_num = 64 |
| self.adaptmlp = MaskedAdapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num, |
| init_option='lora', adapter_scalar=0.1, adapter_layernorm_option='none') |
|
|
| def attention(self, x: torch.Tensor, **kwargs): |
| self.attn_mask = self.attn_mask.to(x) if self.attn_mask is not None else None |
| |
| x = x.permute(1, 0, 2) |
| attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) |
| attn = attn.permute(1, 0, 2) |
|
|
| return attn |
|
|
| def forward(self, x: torch.Tensor, compute_input_matrix = False, **kwargs): |
| |
| x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs)) |
|
|
| x_re = x.permute(1, 0, 2) |
| adapt_x = self.adaptmlp(x_re, add_residual=False, compute_input_matrix=compute_input_matrix) |
| adapt_x = adapt_x.permute(1, 0, 2) |
|
|
| x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x) |
|
|
| return x |
|
|
| class ResidualAttentionBlock_MoE_MLP(ResidualAttentionBlock): |
| def __init__(self, |
| d_model: int, |
| n_head: int, |
| mlp_ratio: float = 4., |
| qkv_bias: bool = True, |
| qk_scale: float = None, |
| attn_drop: float = 0., |
| proj_drop: float = 0., |
| drop_path: float = 0., |
| attn_layer = MultiHeadAttention, |
| act_layer = nn.GELU, |
| norm_layer = nn.LayerNorm, |
| attn_mask: torch.Tensor = None, |
| text_or_image=None, |
| |
| lora_rank: int = 0, |
| lora_bias: bool = False, |
| |
| step: int = 0, |
| experts_num: int = 0, |
| top_k: int = 0, |
| noisy_gating: bool = True |
| ): |
        super().__init__(
            d_model,
            n_head,
            mlp_ratio,
            qkv_bias,
            qk_scale,
            attn_drop,
            proj_drop,
            drop_path,
            attn_layer,
            act_layer,
            norm_layer,
            # Passed by keyword: the parent signature has `norm_layer_eps`
            # before these, so positional passing would misalign them.
            attn_mask=attn_mask,
            text_or_image=text_or_image)
|
|
| assert top_k <= experts_num |
|
|
| self.register_buffer("mean", torch.tensor([0.0])) |
| self.register_buffer("std", torch.tensor([1.0])) |
| self.step = step |
| self.top_k = top_k |
| self.noisy_gating = noisy_gating |
|
|
| self.ffn_num = 64 |
| self.experts_num = experts_num |
| self.softmax = nn.Softmax(1) |
| self.softplus = nn.Softplus() |
| |
| self.router_list = nn.ParameterList([ |
| nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True) for _ in range(self.step) |
| ]) |
| self.w_noise_list = nn.ParameterList([ |
| nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True) for _ in range(self.step) |
| ]) |
|
|
| self.adaptmlp_list = nn.ModuleList([ |
| Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num, |
| init_option='lora', |
| adapter_scalar=0.1, |
| adapter_layernorm_option='none') |
| for _ in range(self.experts_num) |
| ]) |
|
|
| self.lora_feature = None |
| |
| |
| def attention(self, x: torch.Tensor, **kwargs): |
| self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None |
| |
| x = x.permute(1, 0, 2) |
| attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) |
| attn = attn.permute(1, 0, 2) |
|
|
| return attn |
|
|
| def cv_squared(self, x): |
| """The squared coefficient of variation of a sample. |
| Useful as a loss to encourage a positive distribution to be more uniform. |
| Epsilons added for numerical stability. |
| Returns 0 for an empty Tensor. |
| Args: |
| x: a `Tensor`. |
| Returns: |
| a `Scalar`. |
| """ |
| eps = 1e-10 |
| |
|
|
| if x.shape[0] == 1: |
| return torch.tensor([0], device=x.device, dtype=x.dtype) |
| return x.float().var() / (x.float().mean()**2 + eps) |
|
|
| def _gates_to_load(self, gates): |
| """Compute the true load per expert, given the gates. |
| The load is the number of examples for which the corresponding gate is >0. |
| Args: |
| gates: a `Tensor` of shape [batch_size, n] |
| Returns: |
| a float32 `Tensor` of shape [n] |
| """ |
| return (gates > 0).sum(0) |
|
|
| def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): |
| """Helper function to NoisyTopKGating. |
| Computes the probability that value is in top k, given different random noise. |
| This gives us a way of backpropagating from a loss that balances the number |
| of times each expert is in the top k experts per example. |
| In the case of no noise, pass in None for noise_stddev, and the result will |
| not be differentiable. |
| Args: |
| clean_values: a `Tensor` of shape [batch, n]. |
| noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus |
| normally distributed noise with standard deviation noise_stddev. |
| noise_stddev: a `Tensor` of shape [batch, n], or None |
          noisy_top_values: a `Tensor` of shape [batch, m].
             The "values" output of torch.topk(noisy_logits, m), with m >= k + 1.
| Returns: |
| a `Tensor` of shape [batch, n]. |
| """ |
| |
| batch = clean_values.size(0) |
| m = noisy_top_values.size(1) |
| top_values_flat = noisy_top_values.flatten() |
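
        # In the flattened [batch * m] view, entry b * m + k is the (k+1)-th
        # largest noisy logit of example b: the threshold a value must exceed
        # to stay in the top k.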
|
|
| threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.top_k |
| threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) |
| is_in = torch.gt(noisy_values, threshold_if_in) |
| threshold_positions_if_out = threshold_positions_if_in - 1 |
| threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) |
| |
| normal = Normal(self.mean, self.std) |
| |
|
|
| prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev) |
| prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev) |
| prob = torch.where(is_in, prob_if_in, prob_if_out) |
| return prob |
|
|
| def noisy_top_k_gating(self, x, train, w_gate, w_noise, noise_epsilon=1e-2): |
| """Noisy top-k gating. |
| See paper: https://arxiv.org/abs/1701.06538. |
| Args: |
| x: input Tensor with shape [batch_size, input_size] |
| train: a boolean - we only add noise at training time. |
| noise_epsilon: a float |
| Returns: |
| gates: a Tensor with shape [batch_size, num_experts] |
| load: a Tensor with shape [num_experts] |
| """ |
|
|
| clean_logits = x @ w_gate.to(x) |
|
|
| if self.noisy_gating and train: |
| raw_noise_stddev = x @ w_noise.to(x) |
| noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon)) |
| noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) |
| logits = noisy_logits |
| else: |
| logits = clean_logits |
| |
| top_logits, top_indices = logits.topk(min(self.top_k + 1, self.experts_num), dim=1) |
| top_k_logits = top_logits[:, :self.top_k] |
| top_k_indices = top_indices[:, :self.top_k] |
| top_k_gates = self.softmax(top_k_logits) |
| zeros = torch.zeros_like(logits) |
| gates = zeros.scatter(1, top_k_indices, top_k_gates) |
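
        # NOTE: the load-balancing term (via _prob_in_top_k / _gates_to_load)
        # is currently disabled; only the sparse gates are returned.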
| |
| |
| |
| |
| return gates, None |
|
|
| def forward(self, x: torch.Tensor, compute_lora_feat=False, **kwargs): |
| |
| x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs)) |
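
        # Route each sequence by its [CLS] token: gates are computed per
        # example, whole (flattened) sequences are dispatched to the selected
        # adapter experts, and the outputs are recombined with gate weights.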
|
|
| x_re = x.permute(1, 0, 2)[:, 0, :] |
| gates, load = self.noisy_top_k_gating(x_re, self.training, self.router_list[0], |
| self.w_noise_list[0]) |
| |
| dispatcher = SparseDispatcher(self.experts_num, gates) |
| expert_inputs = dispatcher.dispatch(x.permute(1, 0, 2).view(x.shape[1], -1)) |
|
|
| expert_outputs = [self.adaptmlp_list[i](expert_inputs[i].view(expert_inputs[i].shape[0], |
| x.shape[0], x.shape[2]).to(x), add_residual=False) |
| for i in range(self.experts_num)] |
|
|
| expert_outputs = [out.view(out.shape[0], -1) for out in expert_outputs if out.shape[0] > 0] |
|
|
| y = dispatcher.combine(expert_outputs) |
| y = y.view(x.shape[1], x.shape[0], x.shape[2]) |
| x = x + self.drop_path(self.mlp(self.ln_2(x)) + y.permute(1, 0, 2)) |
|
|
| return x |
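| 
| 
| # --- Hedged usage sketch -------------------------------------------------
| # A minimal, self-contained round trip through SparseDispatcher, mirroring
| # the MoE forward above with toy nn.Linear experts and 2-D inputs. The seed,
| # shapes, and expert modules are illustrative assumptions, and
| # dispatcher.combine() is assumed to follow the standard SparseDispatcher
| # interface (concatenate, gate-weight, index_add) defined in this module.
| def _demo_sparse_dispatch(batch=4, dim=8, num_experts=3, top_k=2):
| torch.manual_seed(0)
| # build sparse gates the same way noisy_top_k_gating does (noise omitted)
| logits = torch.randn(batch, num_experts)
| top_logits, top_indices = logits.topk(top_k, dim=1)
| gates = torch.zeros_like(logits).scatter(1, top_indices, F.softmax(top_logits, dim=1))
| experts = [nn.Linear(dim, dim) for _ in range(num_experts)]
| dispatcher = SparseDispatcher(num_experts, gates)
| expert_inputs = dispatcher.dispatch(torch.randn(batch, dim))
| expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
| return dispatcher.combine(expert_outputs)  # [batch, dim], gate-weighted sum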
|
|
| class ResidualAttentionBlock_MoE_Proj(ResidualAttentionBlock): |
| def __init__(self, |
| d_model: int, |
| n_head: int, |
| mlp_ratio: float = 4., |
| qkv_bias: bool = True, |
| qk_scale: float = None, |
| attn_drop: float = 0., |
| proj_drop: float = 0., |
| drop_path: float = 0., |
| attn_layer = MultiHeadAttention, |
| act_layer = nn.GELU, |
| norm_layer = nn.LayerNorm, |
| attn_mask: torch.Tensor = None, |
| text_or_image=None, |
| |
| lora_rank: int = 0, |
| lora_bias: bool = False, |
| |
| experts_num=0, |
| ): |
| super().__init__() |
|
|
| if isinstance(attn_layer, str): |
| try: |
| attn_layer = globals()[attn_layer] |
| except KeyError: |
| print(f'{attn_layer} not found, using default MultiHeadAttention') |
| attn_layer = MultiHeadAttention |
|
|
| if isinstance(act_layer, str): |
| try: |
| act_layer = globals()[act_layer] |
| except KeyError: |
| print(f'{act_layer} not found, using default nn.GELU') |
| act_layer = nn.GELU |
| |
| if isinstance(norm_layer, str): |
| try: |
| norm_layer = globals()[norm_layer] |
| except KeyError: |
| print(f'{norm_layer} not found, using default nn.LayerNorm') |
| norm_layer = nn.LayerNorm |
|
|
| self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop) |
| self.ln_1 = norm_layer(d_model) |
| self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer) |
| self.ln_2 = norm_layer(d_model) |
| self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
| self.attn_mask = attn_mask |
| self.is_train = True |
| |
| if experts_num > 1: |
| self.register_buffer("mean", torch.tensor([0.0])) |
| self.register_buffer("std", torch.tensor([1.0])) |
| self.step = 1 |
| else: |
| self.step = 0 |
| self.top_k = 2 |
| self.ffn_num = 64 |
| self.experts_num = experts_num |
| self.softmax = nn.Softmax(1) |
| self.softplus = nn.Softplus() |
| self.noisy_gating = True |
| self.text_or_image = text_or_image |
| self.router_list = nn.ParameterList() |
| self.w_noise_list = nn.ParameterList() |
|
|
| for i in range(self.step): |
| self.router_list.append(nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True)) |
| self.w_noise_list.append(nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True)) |
| |
| self.adaptmlp_list = nn.ModuleList() |
| for i in range(self.experts_num): |
| self.adaptmlp_list.append(Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num, |
| init_option='lora', |
| adapter_scalar=0.1, |
| adapter_layernorm_option='none', |
| )) |
|
|
| self.lora_feature = None |
| |
| def attention(self, x: torch.Tensor, **kwargs): |
| self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None |
| |
| x = x.permute(1, 0, 2) |
| attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) |
| attn = attn.permute(1, 0, 2) |
|
|
| return attn |
|
|
| def cv_squared(self, x): |
| """The squared coefficient of variation of a sample. |
| Useful as a loss to encourage a positive distribution to be more uniform. |
| Epsilons added for numerical stability. |
| Returns 0 for a Tensor with fewer than two elements.
| Args: |
| x: a `Tensor`. |
| Returns: |
| a `Scalar`. |
| """ |
| eps = 1e-10 |
| |
|
|
| # var() of fewer than two elements is nan, so short-circuit to zero
| if x.shape[0] <= 1:
| return torch.tensor([0], device=x.device, dtype=x.dtype)
| return x.float().var() / (x.float().mean()**2 + eps)
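| 
| # In the mixture-of-experts recipe this code follows (arXiv:1701.06538), the
| # balancing loss is cv_squared(importance) + cv_squared(load) with
| # importance = gates.sum(0); since noisy_top_k_gating below returns
| # load=None, any balancing term here has to be derived from the gates alone.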
|
|
| def _gates_to_load(self, gates): |
| """Compute the true load per expert, given the gates. |
| The load is the number of examples for which the corresponding gate is >0. |
| Args: |
| gates: a `Tensor` of shape [batch_size, n] |
| Returns: |
| an integer `Tensor` of shape [n]
| """ |
| return (gates > 0).sum(0) |
|
|
| def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): |
| """Helper function to NoisyTopKGating. |
| Computes the probability that value is in top k, given different random noise. |
| This gives us a way of backpropagating from a loss that balances the number |
| of times each expert is in the top k experts per example. |
| In the case of no noise, pass in None for noise_stddev, and the result will |
| not be differentiable. |
| Args: |
| clean_values: a `Tensor` of shape [batch, n]. |
| noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus |
| normally distributed noise with standard deviation noise_stddev. |
| noise_stddev: a `Tensor` of shape [batch, n], or None |
| noisy_top_values: a `Tensor` of shape [batch, m]. |
| "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 |
| Returns: |
| a `Tensor` of shape [batch, n]. |
| """ |
| |
| batch = clean_values.size(0) |
| m = noisy_top_values.size(1) |
| top_values_flat = noisy_top_values.flatten() |
|
|
| threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.top_k |
| threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) |
| is_in = torch.gt(noisy_values, threshold_if_in) |
| threshold_positions_if_out = threshold_positions_if_in - 1 |
| threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) |
| |
| normal = Normal(self.mean, self.std) |
| |
|
|
| prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev) |
| prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev) |
| prob = torch.where(is_in, prob_if_in, prob_if_out) |
| return prob |
|
|
| def noisy_top_k_gating(self, x, train, w_gate, w_noise, noise_epsilon=1e-2): |
| """Noisy top-k gating. |
| See paper: https://arxiv.org/abs/1701.06538. |
| Args:
| x: input Tensor with shape [batch_size, input_size]
| train: a boolean - we only add noise at training time.
| w_gate: gating weight Tensor with shape [input_size, num_experts]
| w_noise: noise weight Tensor with shape [input_size, num_experts]
| noise_epsilon: a float
| Returns:
| gates: a Tensor with shape [batch_size, num_experts]
| load: always None here; the differentiable load estimate from the
| original gating is not computed.
| """ |
|
|
| clean_logits = x @ w_gate.to(x) |
| if self.noisy_gating and train: |
| raw_noise_stddev = x @ w_noise.to(x) |
| noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
| noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) |
| logits = noisy_logits |
| else: |
| logits = clean_logits |
| |
| top_logits, top_indices = logits.topk(min(self.top_k + 1, self.experts_num), dim=1) |
| top_k_logits = top_logits[:, :self.top_k] |
| top_k_indices = top_indices[:, :self.top_k] |
| top_k_gates = self.softmax(top_k_logits) |
| zeros = torch.zeros_like(logits) |
| gates = zeros.scatter(1, top_k_indices, top_k_gates) |
| 
| # the importance/load balancing quantities from the original implementation
| # (importance = gates.sum(0), load via _prob_in_top_k) are omitted here,
| # so callers receive None in place of load
| return gates, None
|
|
| def forward(self, x: torch.Tensor, compute_lora_feat=False, **kwargs):
| |
| x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs)) |
|
|
| if self.experts_num == 0: |
|
|
| x = x + self.drop_path(self.mlp(self.ln_2(x))) |
|
|
| elif self.experts_num == 1: |
|
|
| x_re = x.permute(1, 0, 2) |
| adapt_x = self.adaptmlp_list[0](x_re, add_residual=False) |
| adapt_x = adapt_x.permute(1, 0, 2) |
|
|
| x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x) |
|
|
| if compute_lora_feat: |
| self.lora_feature = adapt_x.detach().cpu() |
|
|
| else: |
|
|
| x_re = x.permute(1, 0, 2)[:, 0, :] |
| gates, load = self.noisy_top_k_gating(x_re, self.is_train, self.router_list[0], |
| self.w_noise_list[0]) |
| |
| dispatcher = SparseDispatcher(self.experts_num, gates) |
| expert_inputs = dispatcher.dispatch(x.permute(1, 0, 2).view(x.shape[1], -1)) |
|
|
| expert_outputs = [self.adaptmlp_list[i](expert_inputs[i].view(expert_inputs[i].shape[0], |
| x.shape[0], x.shape[2]).to(x), add_residual=False) |
| for i in range(self.experts_num)] |
|
|
| expert_outputs = [out.view(out.shape[0], -1) for out in expert_outputs if out.shape[0] > 0] |
|
|
| y = dispatcher.combine(expert_outputs) |
| y = y.view(x.shape[1], x.shape[0], x.shape[2]) |
| x = x + self.drop_path(self.mlp(self.ln_2(x)) + y.permute(1, 0, 2)) |
|
|
| return x |
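| 
| 
| # --- Hedged usage sketch -------------------------------------------------
| # The forward above has three regimes: experts_num == 0 (plain MLP),
| # experts_num == 1 (a single always-on adapter), and experts_num > 1 (noisy
| # top-k routing over adapter experts, keyed on the CLS token). This free
| # function isolates the routing math with explicit toy weights; all sizes
| # are illustrative assumptions.
| def _demo_noisy_top_k_routing(batch=4, d_model=8, experts_num=4, top_k=2):
| torch.manual_seed(0)
| cls_feat = torch.randn(batch, d_model)         # plays the role of x_re
| w_gate = torch.randn(d_model, experts_num)     # plays the role of router_list[0]
| w_noise = torch.randn(d_model, experts_num)    # plays the role of w_noise_list[0]
| clean_logits = cls_feat @ w_gate
| noise_stddev = F.softplus(cls_feat @ w_noise) + 1e-2
| logits = clean_logits + torch.randn_like(clean_logits) * noise_stddev
| top_logits, top_indices = logits.topk(min(top_k + 1, experts_num), dim=1)
| top_k_gates = F.softmax(top_logits[:, :top_k], dim=1)
| gates = torch.zeros_like(logits).scatter(1, top_indices[:, :top_k], top_k_gates)
| return gates  # [batch, experts_num], at most top_k nonzeros per row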
|
|
| class ResidualAttentionBiBlock(nn.Module): |
| def __init__(self, |
| d_model: int, |
| n_head: int, |
| mlp_ratio: float = 4., |
| qkv_bias: bool = True, |
| qk_scale: float = None, |
| attn_drop: float = 0., |
| proj_drop: float = 0., |
| drop_path: float = 0., |
| attn_layer = MultiHeadAttention, |
| act_layer = nn.GELU, |
| norm_layer = nn.LayerNorm, |
| attn_mask: torch.Tensor = None, |
| text_or_image=None, |
| |
| lora_rank: int = 0, |
| lora_bias: bool = False |
| ): |
| super().__init__() |
|
|
| if attn_layer == MultiHeadAttention:
| self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop)
| elif attn_layer in (MultiHeadAttention_LoRA, MultiHeadAttention_MaskedLoRA, MultiHeadAttention_MaskedLoRA1,
| MultiHeadAttention_MultiMaskedLoRA, MultiHeadAttention_MultiMaskedLoRA3):
| # every LoRA variant takes the same two extra arguments
| self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias)
| else:
| raise NotImplementedError(f'{attn_layer} not implemented')
| |
| self.ln_1 = norm_layer(d_model) |
| self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer) |
| self.ln_2 = norm_layer(d_model) |
| self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
| self.attn_mask = attn_mask |
| self.text_or_image = text_or_image |
| |
| def attention(self, x: torch.Tensor, x_proj, probs, **kwargs): |
|
|
| self.attn_mask = self.attn_mask.to(x) if self.attn_mask is not None else None |
| |
| x, x_proj = x.permute(1, 0, 2), x_proj.permute(1, 0, 2) |
| attn, attn_proj, probs = self.attn(x, x_proj, probs, attn_mask=self.attn_mask, **kwargs) |
| attn, attn_proj = attn.permute(1, 0, 2), attn_proj.permute(1, 0, 2) |
|
|
| return attn, attn_proj, probs |
|
|
| def forward(self, x: torch.Tensor, x_proj, probs, **kwargs): |
| |
| attn, attn_proj, probs = self.attention(self.ln_1(x), self.ln_1(x_proj), probs, **kwargs) |
|
|
| x = x + self.drop_path(attn) |
| x_proj = x_proj + self.drop_path(attn_proj) |
|
|
| x = x + self.drop_path(self.mlp(self.ln_2(x))) |
| x_proj = x_proj + self.drop_path(self.mlp(self.ln_2(x_proj))) |
|
|
| return x, x_proj, probs |
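| 
| # Note: ResidualAttentionBiBlock runs two token streams through shared
| # weights: x is the regular stream, x_proj its projected twin, and `probs`
| # is a running list of per-layer attention records produced by the
| # bi-stream attention. Inputs and outputs are [seq_len, batch, d_model],
| # matching the other residual blocks in this file.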
|
|
| |
| class Transformer(nn.Module): |
| def __init__(self, |
| width: int, |
| layers: int, |
| heads: int, |
| block_layer = ResidualAttentionBlock, |
| attn_layer = MultiHeadAttention, |
| act_layer = nn.GELU, |
| norm_layer = nn.LayerNorm, |
| attn_mask: torch.Tensor = None, |
| text_or_image=None, |
| **kwargs |
| ): |
| super().__init__() |
| self.width = width |
| self.layers = layers |
|
|
| if isinstance(block_layer, str): |
| try: |
| block_layer = globals()[block_layer] |
| except KeyError: |
| print(f'{block_layer} not found, using default ResidualAttentionBlock') |
| block_layer = ResidualAttentionBlock |
|
|
| if isinstance(attn_layer, str): |
| try: |
| attn_layer = globals()[attn_layer] |
| except KeyError: |
| print(f'{attn_layer} not found, using default MultiHeadAttention') |
| attn_layer = MultiHeadAttention |
|
|
| if isinstance(act_layer, str): |
| try: |
| act_layer = globals()[act_layer] |
| except KeyError: |
| print(f'{act_layer} not found, using default nn.GELU') |
| act_layer = nn.GELU |
| |
| if isinstance(norm_layer, str): |
| try: |
| norm_layer = globals()[norm_layer] |
| except KeyError: |
| print(f'{norm_layer} not found, using default nn.LayerNorm') |
| norm_layer = nn.LayerNorm |
|
|
| self.blocks = nn.ModuleList([ |
| block_layer( |
| d_model=width, |
| n_head=heads, |
| attn_layer=attn_layer, |
| act_layer=act_layer, |
| norm_layer=norm_layer, |
| attn_mask=attn_mask, |
| text_or_image=text_or_image, |
| **kwargs) |
| for _ in range(layers)]) |
|
|
| def forward(self, x: torch.Tensor, l2p_prompt=None, l2p_e_prompt_layer_idx=[], **kwargs): |
|
|
| prompt_counter = -1 |
| for i, block in enumerate(self.blocks): |
| if l2p_prompt is not None and (i in l2p_e_prompt_layer_idx): |
| prompt_counter += 1 |
| batched_prompt = l2p_prompt[prompt_counter] |
| batched_prompt = batched_prompt.permute(1, 0, 2) |
| x = torch.cat([batched_prompt, x], dim=0) |
|
|
| x = block(x, **kwargs) |
| |
| return x |
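| 
| 
| # --- Hedged usage sketch -------------------------------------------------
| # How the L2P prompt path above changes the sequence length: each prompted
| # layer receives a [batch, prompt_len, width] prompt that is permuted and
| # concatenated in front of the tokens. Assumes the default block/attention
| # classes in this module construct with these arguments; sizes are toy values.
| def _demo_transformer_l2p(seq_len=5, batch=2, width=16, layers=2, heads=2, prompt_len=3):
| model = Transformer(width, layers, heads)
| x = torch.randn(seq_len, batch, width)
| l2p_prompt = [torch.randn(batch, prompt_len, width)]  # one entry per prompted layer
| out = model(x, l2p_prompt=l2p_prompt, l2p_e_prompt_layer_idx=[0])
| return out.shape  # [prompt_len + seq_len, batch, width]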
|
|
| class Transformer_Proj(Transformer): |
| def __init__(self, |
| width: int, |
| layers: int, |
| heads: int, |
| block_layer = ResidualAttentionBlock, |
| attn_layer = MultiHeadAttention, |
| act_layer = nn.GELU, |
| norm_layer = nn.LayerNorm, |
| attn_mask: torch.Tensor = None, |
| text_or_image=None, |
| **kwargs |
| ): |
| super().__init__(width, layers, heads, block_layer, attn_layer, act_layer, norm_layer, attn_mask, text_or_image, **kwargs) |
| self.probs = [] |
|
|
| def forward(self, x: torch.Tensor, **kwargs): |
| |
| x_proj = x.clone() |
| self.probs = [] |
| for i, block in enumerate(self.blocks): |
| x, x_proj, self.probs = block(x, x_proj, self.probs, **kwargs) |
|
|
| return x_proj |
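| 
| # Note: Transformer_Proj feeds the same input through both streams of the
| # bi-stream blocks and returns only the projected stream; the per-layer
| # attention records gathered during the pass stay available on self.probs
| # until the next forward call overwrites them.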
|
|
| class Transformer_CL_LoRA(Transformer): |
| def __init__(self, |
| width: int, |
| layers: int, |
| heads: int, |
| block_layer = ResidualAttentionBlock, |
| attn_layer = MultiHeadAttention, |
| act_layer = nn.GELU, |
| norm_layer = nn.LayerNorm, |
| attn_mask: torch.Tensor = None, |
| text_or_image=None, |
| **kwargs |
| ): |
| super().__init__(width, layers, heads, block_layer, attn_layer, act_layer, norm_layer, attn_mask, text_or_image, **kwargs) |
|
|
| def forward(self, x, adapt, prompt, rank_prompt, block_weight, **kwargs): |
|
|
| for idx, blk in enumerate(self.blocks): |
|
|
| if idx >= 6: |
| x = blk( |
| x, |
| adapt = adapt[idx], |
| prompt = prompt, |
| rank_prompt = rank_prompt, |
| block_weight = block_weight[:, idx - 6], |
| **kwargs |
| ) |
| else: |
| x = blk( |
| x, |
| adapt = adapt[idx], |
| prompt = prompt, |
| rank_prompt = rank_prompt, |
| block_weight = None, |
| **kwargs |
| ) |
|
|
| return x |
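| 
| # Note: Transformer_CL_LoRA hard-codes the same 6/6 split that
| # VisionTransformer_CL_LoRA configures below: blocks 0-5 ("general") run
| # without a per-block weight, while blocks 6-11 ("specific") each receive
| # one column of block_weight (block_weight[:, idx - 6]).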
|
|
| |
| class VisualTransformer(nn.Module): |
| def __init__(self, |
| img_size: int, |
| patch_size: int, |
| in_chans: int = 3, |
| width: int = 768, |
| depth: int = 12, |
| heads: int = 8, |
| output_dim: int = 512, |
| text_or_image: str = None, |
| **kwargs |
| ): |
| super().__init__() |
| self.img_size = img_size |
| self.patch_size = patch_size |
| self.in_chans = in_chans |
| self.width = width |
| self.depth = depth |
| self.heads = heads |
| self.output_dim = output_dim |
|
|
| self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) |
|
|
| scale = width ** -0.5 |
| self.class_embedding = nn.Parameter(scale * torch.randn(width)) |
| self.positional_embedding = nn.Parameter(scale * torch.randn((img_size // patch_size) ** 2 + 1, width)) |
| self.ln_pre = LayerNorm(width) |
|
|
| self.transformer = Transformer(width, depth, heads, text_or_image=text_or_image, **kwargs) |
|
|
| self.ln_post = LayerNorm(width) |
| self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) |
|
|
| def forward(self, x: torch.Tensor, **kwargs): |
|
|
| x = self.conv1(x) |
| x = x.reshape(x.shape[0], x.shape[1], -1) |
| x = x.permute(0, 2, 1) |
|
|
| x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) |
| x = x + self.positional_embedding.to(x.dtype) |
| x = self.ln_pre(x) |
|
|
| x = x.permute(1, 0, 2) |
| x = self.transformer(x, **kwargs) |
| x = x.permute(1, 0, 2) |
|
|
| x = self.ln_post(x[:, 0, :]) |
|
|
| if self.proj is not None: |
| x = x @ self.proj |
|
|
| return x |
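| 
| 
| # --- Hedged usage sketch -------------------------------------------------
| # Smoke test for the CLIP-style visual tower above with tiny sizes. Assumes
| # the LayerNorm/Transformer/attention classes defined earlier in this module
| # construct as written; the sizes are illustrative, not trained settings.
| def _demo_visual_transformer():
| model = VisualTransformer(img_size=32, patch_size=16, width=64, depth=2, heads=2, output_dim=32)
| images = torch.randn(2, 3, 32, 32)
| return model(images).shape  # [2, 32]: one projected CLS embedding per image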
|
|
| |
| class VisionTransformer(nn.Module): |
| """ Vision Transformer |
| A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - |
| https://arxiv.org/abs/2010.11929 |
| """ |
| def __init__(self, |
| img_size=224, |
| patch_size=16, |
| in_chans=3, |
| num_classes=1000, |
| embed_dim=768, |
| depth=12, |
| num_heads=12, |
| attn_layer=MultiHeadAttention, |
| mlp_ratio=4., |
| qkv_bias=True, |
| qk_scale=None, |
| representation_size=None, |
| drop_rate=0., |
| attn_drop_rate=0., |
| drop_path_rate=0., |
| norm_layer=nn.LayerNorm, |
| ckpt_layer=0, |
| transformer_layer=Transformer, |
| **kwargs): |
| """ |
| Args: |
| img_size (int, tuple): input image size |
| patch_size (int, tuple): patch size |
| in_chans (int): number of input channels |
| num_classes (int): number of classes for classification head |
| embed_dim (int): embedding dimension |
| depth (int): depth of transformer |
| num_heads (int): number of attention heads |
| mlp_ratio (int): ratio of mlp hidden dim to embedding dim |
| qkv_bias (bool): enable bias for qkv if True |
| qk_scale (float): override default qk scale of head_dim ** -0.5 if set |
| representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set |
| drop_rate (float): dropout rate |
| attn_drop_rate (float): attention dropout rate |
| drop_path_rate (float): stochastic depth rate |
| norm_layer: (nn.Module): normalization layer |
| """ |
| super().__init__() |
|
|
| self.num_features = self.embed_dim = embed_dim |
| self.num_heads = num_heads |
|
|
| self.patch_embed = PatchEmbed( |
| img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) |
|
|
| num_patches = self.patch_embed.num_patches |
|
|
| self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
| self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) |
| self.pos_drop = nn.Dropout(p=drop_rate) |
| if transformer_layer == 'Transformer_Proj': |
| self.transformer = Transformer_Proj(embed_dim, depth, num_heads, text_or_image='image', attn_layer=attn_layer, norm_layer=norm_layer, **kwargs) |
| elif transformer_layer == 'Transformer_CL_LoRA': |
| self.transformer = Transformer_CL_LoRA(embed_dim, depth, num_heads, text_or_image='image', attn_layer=attn_layer, norm_layer=norm_layer, **kwargs) |
| else: |
| self.transformer = Transformer(embed_dim, depth, num_heads, text_or_image='image', attn_layer=attn_layer, norm_layer=norm_layer, **kwargs) |
| self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
|
|
| trunc_normal_(self.pos_embed, std=.02) |
| trunc_normal_(self.cls_token, std=.02) |
| self.apply(self._init_weights) |
|
|
| def _init_weights(self, m): |
| if isinstance(m, nn.Linear): |
| trunc_normal_(m.weight, std=.02) |
| if isinstance(m, nn.Linear) and m.bias is not None: |
| nn.init.constant_(m.bias, 0) |
| elif isinstance(m, nn.LayerNorm): |
| nn.init.constant_(m.bias, 0) |
| nn.init.constant_(m.weight, 1.0) |
|
|
| @torch.jit.ignore |
| def no_weight_decay(self): |
| return {'pos_embed', 'cls_token'} |
|
|
| def forward(self, x, register_blk=-1, prompt=None, prompt_flag='', q=None, train=False, task_id=-1, cls_features=None, **kwargs): |
|
|
| B = x.shape[0] |
| x = self.patch_embed(x) |
|
|
| if prompt_flag == 'l2p': |
|
|
| batched_prompt = None |
| e_prompt_layer_idx = [] |
| if prompt:
| # prompts are injected only at the first block
| e_prompt_layer_idx = [0]
| total_prompt_len = prompt.length * prompt.top_k * len(e_prompt_layer_idx)
|
|
| batched_prompt, reduce_sim = prompt(x, cls_features=cls_features) |
|
|
| cls_tokens = self.cls_token.expand(B, -1, -1) |
| x = torch.cat((cls_tokens, x), dim=1) |
| |
| x = x + self.pos_embed[:, :x.size(1), :] |
| x = self.pos_drop(x) |
|
|
| x = x.permute(1, 0, 2) |
| x = self.transformer( |
| x, |
| l2p_prompt = batched_prompt, |
| l2p_e_prompt_layer_idx = e_prompt_layer_idx, |
| **kwargs |
| ) |
| x = x.permute(1, 0, 2) |
|
|
| x = self.norm(x) |
|
|
| if prompt: |
| x = x[:, :total_prompt_len] |
| x = x.mean(dim=1) |
| return x, reduce_sim |
| else: |
| return x[:, 0] |
|
|
| else: |
|
|
| cls_tokens = self.cls_token.expand(B, -1, -1) |
| x = torch.cat((cls_tokens, x), dim=1) |
|
|
| x = x + self.pos_embed[:,:x.size(1),:] |
| x = self.pos_drop(x) |
|
|
| |
| # the .to() keeps prompt_loss a non-leaf tensor, so the in-place += below
| # is legal even though requires_grad is True
| prompt_loss = torch.zeros((1,), requires_grad=True).to(x.device)
| if prompt is not None:
| for i, blk in enumerate(self.transformer.blocks):
| # query the prompt module for this block, then run the block
| if train:
| p_list, loss, x = prompt.forward(q, i, x, train=True, task_id=task_id)
| prompt_loss += loss
| else:
| p_list, _, x = prompt.forward(q, i, x, train=False, task_id=task_id)
| 
| 
| x = x.permute(1, 0, 2)
| x = blk(x, register_hook=register_blk==i, prompt=p_list, **kwargs)
| x = x.permute(1, 0, 2)
| else: |
|
|
| x = x.permute(1, 0, 2) |
| x = self.transformer(x, **kwargs) |
| x = x.permute(1, 0, 2) |
|
|
| x = self.norm(x) |
| return x, prompt_loss |
|
|
| @torch.jit.ignore() |
| def load_pretrained(self, checkpoint_path, prefix=''): |
| _load_weights(self, checkpoint_path, prefix) |
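| 
| 
| # --- Hedged usage sketch -------------------------------------------------
| # The prompt-free path of VisionTransformer.forward: with no prompt module it
| # reduces to a plain ViT encoder and returns the full token sequence plus a
| # zero prompt loss. Tiny illustrative sizes; assumes the module defaults
| # construct as written.
| def _demo_vision_transformer():
| model = VisionTransformer(img_size=32, patch_size=16, embed_dim=64, depth=2, num_heads=2)
| images = torch.randn(2, 3, 32, 32)
| tokens, prompt_loss = model(images)
| return tokens.shape  # [2, 1 + (32 // 16) ** 2, 64] == [2, 5, 64]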
|
|
| class VisionTransformer_CL_LoRA(VisionTransformer): |
| """ Vision Transformer |
| A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - |
| https://arxiv.org/abs/2010.11929 |
| """ |
|
|
| class Adapter_lora(nn.Module): |
| def __init__(self, |
| config=None, |
| d_model=None, |
| bottleneck=None, |
| dropout=0.0, |
| init_option="bert", |
| adapter_scalar="1.0", |
| adapter_layernorm_option="in"): |
| super().__init__() |
|
|
| self.n_embd = config.d_model if d_model is None else d_model |
| self.down_size = config.attn_bn if bottleneck is None else bottleneck |
|
|
| self.lora_A = nn.Linear(self.down_size, self.n_embd, bias=False) |
| self.lora_B = nn.Linear(self.n_embd, self.down_size, bias=False) |
|
|
| # initialize the down-projection with an orthonormal basis via reduced QR
| random_matrix = torch.rand(self.n_embd, self.down_size)
| q, r = torch.linalg.qr(random_matrix)
| with torch.no_grad():
| self.lora_B.weight.copy_(q.T)
| scaling_factor = 1.
| self.lora_B.weight.data *= scaling_factor  # no-op at 1.0; kept as a rescaling hook
|
|
| if init_option == "bert": |
| raise NotImplementedError |
| elif init_option == "lora": |
| with torch.no_grad(): |
| nn.init.zeros_(self.lora_A.weight) |
| else: |
| raise NotImplementedError |
|
|
| def forward(self, x): |
| inter_x = self.lora_B(x) |
| out = self.lora_A(inter_x) |
| return out |
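| 
| # Note: despite the names, lora_B is the down-projection
| # (d_model -> bottleneck) and lora_A the up-projection, so forward computes
| # a rank-`bottleneck` update out = lora_A(lora_B(x)). With B drawn from an
| # orthonormal basis and A zero-initialized, the adapter output is exactly
| # zero at initialization.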
|
|
| def __init__(self, |
| img_size=224, |
| patch_size=16, |
| in_chans=3, |
| num_classes=1000, |
| embed_dim=768, |
| depth=12, |
| num_heads=12, |
| attn_layer=MultiHeadAttention, |
| mlp_ratio=4., |
| qkv_bias=True, |
| qk_scale=None, |
| representation_size=None, |
| drop_rate=0., |
| attn_drop_rate=0., |
| drop_path_rate=0., |
| norm_layer=nn.LayerNorm, |
| ckpt_layer=0, |
| transformer_layer=Transformer, |
| **kwargs): |
| """ |
| Args: |
| img_size (int, tuple): input image size |
| patch_size (int, tuple): patch size |
| in_chans (int): number of input channels |
| num_classes (int): number of classes for classification head |
| embed_dim (int): embedding dimension |
| depth (int): depth of transformer |
| num_heads (int): number of attention heads |
| mlp_ratio (int): ratio of mlp hidden dim to embedding dim |
| qkv_bias (bool): enable bias for qkv if True |
| qk_scale (float): override default qk scale of head_dim ** -0.5 if set |
| representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set |
| drop_rate (float): dropout rate |
| attn_drop_rate (float): attention dropout rate |
| drop_path_rate (float): stochastic depth rate |
| norm_layer: (nn.Module): normalization layer |
| """ |
| super().__init__( |
| img_size=img_size, |
| patch_size=patch_size, |
| in_chans=in_chans, |
| num_classes=num_classes, |
| embed_dim=embed_dim, |
| depth=depth, |
| num_heads=num_heads, |
| attn_layer=attn_layer, |
| mlp_ratio=mlp_ratio, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| representation_size=representation_size, |
| drop_rate=drop_rate, |
| attn_drop_rate=attn_drop_rate, |
| drop_path_rate=drop_path_rate, |
| norm_layer=norm_layer, |
| ckpt_layer=ckpt_layer, |
| transformer_layer=transformer_layer, |
| **kwargs |
| ) |
|
|
|
|
| # hard-coded tuning configuration for the CL-LoRA setup
| cfg_dict = {
| 'use_distillation': True, |
| 'use_block_weight': True, |
| 'msa_adapt': True, |
| 'msa': [1, 0, 1], |
| 'specfic_pos': [6, 7, 8, 9, 10, 11], |
| 'general_pos': [0, 1, 2, 3, 4, 5], |
| 'ffn_adapt': True, |
| 'ffn_option': 'parallel', |
| 'ffn_adapter_layernorm_option': 'none', |
| 'ffn_adapter_init_option': 'lora', |
| 'ffn_adapter_scalar': '0.1', |
| 'ffn_num': kwargs['lora_rank'], |
| 'd_model': 768, |
| 'vpt_on': False, |
| 'vpt_num': 0, |
| '_device': 'cuda:0' |
| } |
| |
| from types import SimpleNamespace |
|
|
| self.tuning_config = SimpleNamespace(**cfg_dict) |
| self.config = self.tuning_config |
|
|
| self._device = self.tuning_config._device |
| self.msa_adapt = self.tuning_config.msa_adapt |
| self.use_distillation = self.tuning_config.use_distillation |
| self.use_block_weight = self.tuning_config.use_block_weight |
|
|
| self.general_pos = self.tuning_config.general_pos |
| self.specfic_pos = self.tuning_config.specfic_pos |
| self.adapt_pos = self.general_pos + self.specfic_pos |
| self.adapt_pos = sorted(self.adapt_pos) |
|
|
| if self.msa_adapt: |
| self.msa = self.tuning_config.msa |
|
|
| if self.use_distillation: |
| self.old_adapter_list = nn.ModuleList() |
|
|
| if self.use_block_weight: |
| self.block_weight_list = [] |
| self.block_weight = nn.Parameter(torch.randn(3, len(self.specfic_pos))) |
| nn.init.uniform_(self.block_weight, .5, 1.5) |
|
|
| self.adapter_list = [] |
| self.adapter_pos_list = [] |
| self.cur_adapter = nn.ModuleList() |
| self.get_new_adapter_initial_msa() |
|
|
| def forward(self, x, test = False, register_blk=-1, prompt=None, prompt_flag='', q=None, train=False, task_id=-1, cls_features=None, **kwargs): |
|
|
| if not test: |
| output = self.forward_train(x) |
| output = output[:, 0] |
| return output, None |
|
|
| else:
| features = self.forward_test(x)
| # concatenate the CLS token from every stored-adapter pass along the feature dim
| output = torch.cat([feat[:, 0, :] for feat in features], dim=1)
| return output, None
|
|
| def forward_train(self, x): |
|
|
| B = x.shape[0] |
| x = self.patch_embed(x) |
|
|
| cls_tokens = self.cls_token.expand(B, -1, -1) |
| x = torch.cat((cls_tokens, x), dim=1) |
|
|
| x = x + self.pos_embed[:,:x.size(1),:] |
| x = self.pos_drop(x) |
|
|
| x = x.permute(1, 0, 2) |
| |
| x = self.transformer( |
| x, |
| adapt = self.cur_adapter, |
| prompt = None, |
| rank_prompt = None, |
| block_weight=self.block_weight) |
| x = x.permute(1, 0, 2) |
| x = self.norm(x) |
|
|
| return x |
|
|
| def forward_test(self, x, use_init_ptm=False): |
| import copy |
| B = x.shape[0] |
| x = self.patch_embed(x) |
|
|
| cls_tokens = self.cls_token.expand(B, -1, -1) |
| x = torch.cat((cls_tokens, x), dim=1) |
| x = x + self.pos_embed |
| x_init = self.pos_drop(x) |
|
|
| features = [] |
| assert self.config.ffn_adapt |
| assert self.adapt_pos == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] |
| assert self.general_pos == [0, 1, 2, 3, 4, 5] |
| assert self.use_block_weight |
|
|
| |
|
|
| for i in range(len(self.adapter_list)): |
| x = copy.deepcopy(x_init) |
|
|
| x = x.permute(1, 0, 2) |
| for idx, blk in enumerate(self.transformer.blocks): |
|
|
| if idx >= 6: |
| x = blk(x, adapt = self.adapter_list[i][idx - 6], prompt = None, rank_prompt = None, |
| block_weight=self.block_weight_list[i][:, idx - 6]) |
| else: |
| x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, block_weight=None) |
| x = x.permute(1, 0, 2) |
|
|
| x = self.norm(x) |
| features.append(x) |
|
|
| x = copy.deepcopy(x_init) |
| x = x.permute(1, 0, 2) |
| for idx, blk in enumerate(self.transformer.blocks): |
|
|
| if idx >= 6: |
| x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, |
| block_weight=self.block_weight[:, idx - 6]) |
| else: |
| x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, block_weight=None) |
| x = x.permute(1, 0, 2) |
|
|
|
|
| x = self.norm(x) |
| features.append(x) |
|
|
| return features |
|
|
| def forward_proto(self, x, adapt_index): |
| assert adapt_index > -1 |
| assert self.config.ffn_adapt |
| assert self.use_block_weight |
|
|
| B = x.shape[0] |
| x = self.patch_embed(x) |
|
|
| cls_tokens = self.cls_token.expand(B, -1, -1) |
| x = torch.cat((cls_tokens, x), dim=1) |
| x = x + self.pos_embed |
| x = self.pos_drop(x) |
|
|
|
|
| if adapt_index < len(self.adapter_list): |
| |
| x = x.permute(1, 0, 2) |
| for idx, blk in enumerate(self.transformer.blocks): |
|
|
| if idx >= 6: |
| x = blk(x, adapt = self.adapter_list[adapt_index][idx - 6], prompt = None, rank_prompt = None, |
| block_weight=self.block_weight_list[adapt_index][:, idx - 6]) |
| else: |
| x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, block_weight=None) |
| x = x.permute(1, 0, 2) |
|
|
| else: |
| |
| x = x.permute(1, 0, 2) |
| for idx, blk in enumerate(self.transformer.blocks): |
|
|
| if idx >= 6: |
| x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, |
| block_weight=self.block_weight[:, idx - 6]) |
| else: |
| x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, block_weight=None) |
| x = x.permute(1, 0, 2) |
|
|
| x = self.norm(x) |
| x = x[:, 0, :] |
|
|
| return x |
|
|
| def forward_general_cls(self, x, t_idx): |
| import copy |
| B = x.shape[0] |
| x = self.patch_embed(x) |
|
|
| cls_tokens = self.cls_token.expand(B, -1, -1) |
| x = torch.cat((cls_tokens, x), dim=1) |
| x = x + self.pos_embed |
| x = self.pos_drop(x) |
|
|
| x_teacher = copy.deepcopy(x) |
|
|
| for j in range(6): |
| x = self.transformer.blocks[j](x, adapt = self.cur_adapter[j]) |
| x_teacher = self.transformer.blocks[j](x_teacher, adapt = self.old_adapter_list[t_idx-1][j]) |
|
|
| x = self.norm(x) |
| output_new = x[:, 0, :] |
|
|
| x_teacher = self.norm(x_teacher) |
| output_teacher= x_teacher[:, 0, :] |
|
|
| return output_new, output_teacher |
|
|
| def get_new_adapter_initial_msa(self): |
|
|
| config = self.config |
| if config.ffn_adapt: |
| for i in range(len(self.adapt_pos)): |
| temp_adapter = nn.ModuleList() |
| for j in self.msa: |
| if j == 1:
| adapter = VisionTransformer_CL_LoRA.Adapter_lora(self.config, dropout=0.0, bottleneck=config.ffn_num, |
| init_option=config.ffn_adapter_init_option, |
| adapter_scalar=config.ffn_adapter_scalar, |
| adapter_layernorm_option=config.ffn_adapter_layernorm_option, |
| ).to(self._device) |
| else: |
| adapter = nn.Identity() |
| temp_adapter.append(adapter) |
|
|
| self.cur_adapter.append(temp_adapter) |
| self.cur_adapter.requires_grad_(True) |
|
|
| else: |
| print("====Not use adapter===") |
|
|
| def add_adapter_to_list(self): |
| temp_adapter = [] |
| import copy |
| for i in range(len(self.specfic_pos)): |
| temp_pos = self.adapt_pos.index(self.specfic_pos[i]) |
| temp_adapter.append(copy.deepcopy(self.cur_adapter[temp_pos].requires_grad_(False))) |
| self.adapter_list.append(temp_adapter) |
|
|
| if self.use_block_weight: |
| self.block_weight_old = copy.deepcopy(self.block_weight) |
| self.block_weight_list.append(self.block_weight_old.requires_grad_(False)) |
| self.block_weight = nn.Parameter(torch.randn(3, len(self.specfic_pos))) |
| nn.init.uniform_(self.block_weight, .5, 1.5) |
|
|
| self.adapter_pos_list.append(self.adapt_pos) |
|
|
| if self.use_distillation: |
| self.old_adapter_list.append(copy.deepcopy(self.cur_adapter).requires_grad_(False)) |
| if self.msa_adapt: |
| self.get_new_adapter_msa() |
|
|
| def get_new_adapter_msa(self): |
| config = self.config |
|
|
| if config.ffn_adapt: |
| for i in range(len(self.specfic_pos)): |
| pos = self.adapt_pos.index(self.specfic_pos[i]) |
| temp_adapter = nn.ModuleList() |
| for j in self.msa: |
| if j == 1: |
| adapter = VisionTransformer_CL_LoRA.Adapter_lora(self.config, dropout=0.0, bottleneck=config.ffn_num, |
| init_option=config.ffn_adapter_init_option, |
| adapter_scalar=config.ffn_adapter_scalar, |
| adapter_layernorm_option=config.ffn_adapter_layernorm_option, |
| ).to(self._device) |
| adapter.requires_grad_(True) |
| else: |
| adapter = nn.Identity() |
| temp_adapter.append(adapter) |
| self.cur_adapter[pos] = temp_adapter |
|
|
| if len(self.specfic_pos) < 12: |
| self.cur_adapter.requires_grad_(True) |
|
|
| for i in self.adapt_pos: |
| if i in self.general_pos: |
| pos = self.adapt_pos.index(i) |
| for j in range(len(self.msa)): |
| if self.msa[j] == 1: |
| self.cur_adapter[pos][j].lora_B.requires_grad_(False) |
| else: |
| print("====Not use adapter===") |
|
|
| @torch.jit.ignore() |
| def load_pretrained(self, checkpoint_path, prefix=''): |
| _load_weights(self, checkpoint_path, prefix) |
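| 
| 
| # --- Hedged task-loop sketch ----------------------------------------------
| # The continual-learning lifecycle implied by the class above, in comment
| # pseudocode (the actual trainer wiring lives outside this file):
| #   model = VisionTransformer_CL_LoRA(transformer_layer='Transformer_CL_LoRA',
| #                                     lora_rank=64, ...)        # assumed kwargs
| #   for task in tasks:
| #       train_one_task(model)          # model(x, test=False) -> CLS features
| #       model.add_adapter_to_list()    # freeze + store the task-specific
| #                                      # adapters and block_weight, re-init both
| #   feats, _ = model(x, test=True)     # CLS features from every stored adapter
| #                                      # set plus the current one, concatenated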
|
|
| @torch.no_grad() |
| def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): |
| """ Load weights from .npz checkpoints for official Google Brain Flax implementation |
| """ |
| import numpy as np |
|
|
| def _n2p(w, t=True): |
| if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: |
| w = w.flatten() |
| if t: |
| if w.ndim == 4: |
| w = w.transpose([3, 2, 0, 1]) |
| elif w.ndim == 3: |
| w = w.transpose([2, 0, 1]) |
| elif w.ndim == 2: |
| w = w.transpose([1, 0]) |
| return torch.from_numpy(w) |
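| 
| # _n2p converts Flax array layouts to PyTorch: conv kernels go from
| # [H, W, in, out] to [out, in, H, W], linear weights are transposed, and
| # 1x1x1 "kernels" are flattened to plain vectors before conversion.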
|
|
| w = np.load(checkpoint_path) |
| if not prefix and 'opt/target/embedding/kernel' in w: |
| prefix = 'opt/target/' |
|
|
| if hasattr(model.patch_embed, 'backbone'): |
| |
| backbone = model.patch_embed.backbone |
| stem_only = not hasattr(backbone, 'stem') |
| stem = backbone if stem_only else backbone.stem |
| stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) |
| stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) |
| stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) |
| if not stem_only: |
| for i, stage in enumerate(backbone.stages): |
| for j, block in enumerate(stage.blocks): |
| bp = f'{prefix}block{i + 1}/unit{j + 1}/' |
| for r in range(3): |
| getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) |
| getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) |
| getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) |
| if block.downsample is not None: |
| block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) |
| block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) |
| block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) |
| embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) |
| else: |
| embed_conv_w = adapt_input_conv( |
| model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) |
| model.patch_embed.proj.weight.copy_(embed_conv_w) |
| model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) |
| model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) |
| pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) |
| if pos_embed_w.shape != model.pos_embed.shape: |
| pos_embed_w = resize_pos_embed( |
| pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) |
| model.pos_embed.copy_(pos_embed_w) |
| model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) |
| model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) |
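| 
| # NOTE: the loop below follows the timm checkpoint layout (model.blocks,
| # block.norm1/norm2, block.attn.qkv, block.mlp.fc1/fc2). The classes in this
| # file expose their blocks as model.transformer.blocks with ln_1/ln_2 naming,
| # so loading into them would additionally require remapping attribute names.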
|
|
| for i, block in enumerate(model.blocks.children()): |
| block_prefix = f'{prefix}Transformer/encoderblock_{i}/' |
| mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' |
| block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) |
| block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) |
| block.attn.qkv.weight.copy_(torch.cat([ |
| _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) |
| block.attn.qkv.bias.copy_(torch.cat([ |
| _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) |
| block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) |
| block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) |
| for r in range(2): |
| getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) |
| getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) |
| block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) |
| block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) |
|
|