aydnarda committed on
Commit
05a82cf
·
verified ·
1 Parent(s): 5b0eba9

upload supp files

Browse files
bert_modeling_bert_self_attn_patch.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers.models.bert import modeling_bert
3
+ from typing import Optional, Tuple
4
+ import torch.nn as nn
5
+ import math
6
+
7
def patch_bert_self_attn():
    """Monkey-patch ``BertSelfAttention.forward`` so that, when
    ``output_attentions=True``, the layer returns the *scaled, pre-mask,
    pre-softmax* attention scores instead of the post-softmax attention
    probabilities.

    Apart from that single change, the patched forward mirrors the
    upstream Hugging Face implementation (cross-attention, KV cache and
    relative position embeddings included).
    """

    def bert_self_attn_forward_patched(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
            output_attentions: Optional[bool] = False):

        mixed_query_layer = self.query(hidden_states)

        # Cross-attention: keys/values come from the encoder and the
        # encoder's mask replaces the decoder mask.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # Reuse cached encoder keys/values.
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Self-attention with KV cache: append the new keys/values
            # along the sequence dimension (dim=2).
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # Decoder layers cache (key, value) for incremental decoding.
            past_key_value = (key_layer, value_layer)

        # Raw dot-product scores: (batch, heads, q_len, k_len).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                # With a cache only the newest query position is scored.
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # The point of this patch: keep the scaled scores *before* the
        # additive mask and softmax, so callers asking for attentions get
        # raw logits rather than probabilities.
        attn_scores = attention_scores
        if attention_mask is not None:
            # attention_mask is additive (0 keep, large negative drop).
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Dropout on the attention probabilities, as in upstream BERT.
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, heads, seq, head_dim) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        # Upstream returns attention_probs here; this patch returns the
        # unmasked, pre-softmax scores instead.
        outputs = (context_layer, attn_scores) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

    modeling_bert.BertSelfAttention.forward = bert_self_attn_forward_patched
loralib/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .layers import *
2
+ from .utils import *
loralib/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (195 Bytes). View file
 
loralib/__pycache__/layers.cpython-310.pyc ADDED
Binary file (21.8 kB). View file
 
loralib/__pycache__/utils.cpython-310.pyc ADDED
Binary file (5.48 kB). View file
 
loralib/easymultiheadattention.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ """
6
+ Source : https://github.com/KyanChen/MakeMultiHeadNaive/blob/master/main.py
7
+ """
8
+
9
class PlainMultiHeadAttention(nn.Module):
    """Drop-in, unfused replacement for an ``nn.MultiheadAttention``.

    Copies the fused projection weights of ``existing_mha`` into a single
    ``qkv`` linear plus an output ``proj`` linear and evaluates attention
    with ``F.scaled_dot_product_attention``.

    NOTE(review): ``key_padding_mask`` is canonicalized below but never
    merged into ``attn_mask``, so it appears to be ignored — confirm
    callers never rely on it. Attention weights are always returned as
    ``None`` regardless of ``need_weights``. ``F._canonical_mask`` and
    ``F._none_or_dtype`` are private torch APIs and may change between
    releases.
    """

    def __init__(
        self,
        existing_mha: nn.MultiheadAttention):
        super().__init__()

        self.dropout = 0 # this module is not used to retrain the main block
        self.embed_dim = existing_mha.embed_dim
        self.kdim = existing_mha.kdim
        self.vdim = existing_mha.vdim
        self._qkv_same_embed_dim = existing_mha._qkv_same_embed_dim
        self.num_heads = existing_mha.num_heads
        self.batch_first = existing_mha.batch_first
        self.head_dim = existing_mha.head_dim
        self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=existing_mha.in_proj_bias is not None)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.out_proj.bias is not None)

        # Initialize parameters from the source module (no re-training).
        with torch.no_grad():
            self.qkv.weight.data.copy_(existing_mha.in_proj_weight.data)
            if self.qkv.bias is not None:
                self.qkv.bias.data.copy_(existing_mha.in_proj_bias.data)
            self.proj.weight.data.copy_(existing_mha.out_proj.weight.data)
            if self.proj.bias is not None:
                self.proj.bias.data.copy_(existing_mha.out_proj.bias.data)

        self.scaled_dot_product_attention = F.scaled_dot_product_attention

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        average_attn_weights=True,
        is_causal=False):
        """Mirror of ``nn.MultiheadAttention.forward``.

        Returns ``(attn_output, None)`` — per-head attention weights are
        never computed by this implementation.
        """

        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")
        is_batched = query.dim() == 3
        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )

        # Work sequence-first internally: (tgt_len, bsz, embed_dim).
        if self.batch_first and is_batched:
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape

        # Fused projection, then split into q, k, v along the new axis.
        E = query.size(-1)
        qkv = self.qkv(query)
        qkv = qkv.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=F._none_or_dtype(key_padding_mask),
            other_name="key_padding_mask",
            target_type=q.dtype,
            check_other=False,
        )

        if attn_mask is not None:
            # ensure attn_mask's dim is 3
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * self.num_heads, tgt_len, src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        if attn_mask is not None:
            # Reshape to (bsz, num_heads, tgt_len, src_len) for SDPA.
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, self.num_heads, -1, src_len)

        # self.dropout is 0 here, so dropout_p is always 0; kept for
        # structural parity with torch's reference implementation.
        dropout_p = self.dropout if self.training else 0.

        # (seq, bsz, E) -> (bsz, num_heads, seq, head_dim)
        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        src_len = k.size(1)
        q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
        k = k.view(bsz, self.num_heads, src_len, self.head_dim)
        v = v.view(bsz, self.num_heads, src_len, self.head_dim)

        attn_output = self.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
        # (bsz, heads, tgt, head_dim) -> (tgt, bsz, heads, head_dim) -> (tgt*bsz, E)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        attn_output = self.proj(attn_output)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), None
        return attn_output, None
loralib/layers.py ADDED
@@ -0,0 +1,937 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------
2
+ # This code is reconstructed based on loralib (https://github.com/microsoft/LoRA) by Baijiong Lin.
3
+ # ------------------------------------------------------------------------------------------
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import math
9
+ from typing import Optional, List
10
+ from torch.jit import Final
11
+ from timm.layers import use_fused_attn
12
+ from timm.models.vision_transformer import Attention
13
+ from transformers.models.bert.modeling_bert import BertAttention
14
+ from typing import Optional, Tuple
15
+
16
def set_param(curr_mod, name, param=None, mode='update'):
    r"""Get or replace a (possibly nested) attribute of a module.

    ``name`` may be dotted (e.g. ``'encoder.weight'``); the leading
    component is resolved through ``named_children`` recursively. With
    ``mode='update'`` the attribute is deleted and re-bound to ``param``
    (so a plain tensor can replace a registered ``nn.Parameter``); with
    ``mode='get'`` the current value is returned, or ``None`` if absent.
    Refer to https://github.com/Baijiong-Lin/MOML/blob/main/MTL/utils.py
    """
    if '.' not in name:
        # Base case: the attribute lives directly on this module.
        if mode == 'update':
            delattr(curr_mod, name)
            setattr(curr_mod, name, param)
        elif mode == 'get':
            if hasattr(curr_mod, name):
                return getattr(curr_mod, name)
        return None
    # Recursive case: descend into the matching child module.
    head, _, tail = name.partition('.')
    for child_name, child in curr_mod.named_children():
        if child_name == head:
            return set_param(child, tail, param, mode=mode)
33
+
34
class LoRALayer():
    """Mixin with the shared LoRA bookkeeping for the layer classes below.

    Subclasses co-inherit from an ``nn.Module`` subclass (which supplies
    ``register_parameter`` and parameter storage) and populate
    ``self.params_with_lora`` as ``{'param_name': 'lora_name'}``.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        fan_in_fan_out: bool = False,
        dropout_rate: float = 0,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.dropout_rate = dropout_rate
        if self.r > 0:
            # rsLoRA-style scaling (alpha / sqrt(r)) instead of the
            # classic alpha / r.
            self.scaling = self.lora_alpha / math.sqrt(self.r)
        # Mark the weight as unmerged
        self.merged = False
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        self.fan_in_fan_out = fan_in_fan_out
        # define params that require LoRA {'param_name': 'lora_name'}
        self.params_with_lora = {}

    def _nested_attr(self, dotted_name: str):
        """Resolve a possibly dotted attribute path relative to ``self``.

        Replaces the previous ``eval(f'self.{name}')`` calls, which were
        slower and an injection hazard for attacker-controlled names.
        """
        obj = self
        for part in dotted_name.split('.'):
            obj = getattr(obj, part)
        return obj

    def register_lora_param(self):
        r"""Register LoRA factors A (r x fan_in) and B (fan_out x r) and freeze the base weight."""
        for param_name, lora_name in self.params_with_lora.items():
            w = self._nested_attr(param_name)
            assert len(w.size()) == 2
            self.register_parameter(
                f'{lora_name}_lora_A',
                nn.Parameter(w.new_zeros((self.r, w.size()[1])))
            )
            self.register_parameter(
                f'{lora_name}_lora_B',
                nn.Parameter(w.new_zeros((w.size()[0], self.r)))
            )
            # Only the LoRA factors train; the pre-trained weight is frozen.
            w.requires_grad = False

    def init_lora_param(self):
        """Kaiming-init A and zero B, so the initial LoRA delta is zero."""
        for param_name, lora_name in self.params_with_lora.items():
            if hasattr(self, f'{lora_name}_lora_A'):
                # initialize A the same way as the default for nn.Linear and B to zero
                nn.init.kaiming_uniform_(self._nested_attr(f'{lora_name}_lora_A'), a=math.sqrt(5))
                nn.init.zeros_(self._nested_attr(f'{lora_name}_lora_B'))

    def transpose(self, w: torch.Tensor):
        """Transpose only for (fan_in, fan_out)-stored weights."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def merge_BA(self, param_name: str):
        """Return B @ A reshaped (and optionally transposed) to the base weight's shape."""
        lora_name = self.params_with_lora[param_name]
        lora_B = self._nested_attr(f'{lora_name}_lora_B')
        lora_A = self._nested_attr(f'{lora_name}_lora_A')
        return self.transpose((lora_B @ lora_A).view(self._nested_attr(param_name).shape))

    def merge_lora_param(self):
        r"""p_new = p + scaling * B @ A and keep differentiable to A and B"""
        for param_name, lora_name in self.params_with_lora.items():
            p = set_param(self, param_name, mode='get')
            # detach() is very important here: gradients must flow only to
            # the LoRA factors, never to the frozen base weight.
            p_new = p.detach() + self.merge_BA(param_name) * self.scaling
            set_param(self, param_name, param=p_new, mode='update')

    def add_lora_data(self):
        r"""In-place merge of the LoRA delta into the base weight (NOT differentiable)."""
        for param_name, lora_name in self.params_with_lora.items():
            self._nested_attr(param_name).data += self.merge_BA(param_name) * self.scaling

    def sub_lora_data(self):
        r"""Inverse of :meth:`add_lora_data` (NOT differentiable)."""
        for param_name, lora_name in self.params_with_lora.items():
            self._nested_attr(param_name).data -= self.merge_BA(param_name) * self.scaling

    def lora_train(self, mode: bool = True):
        """Keep weights unmerged while training, merged for inference."""
        if mode:
            if self.merged and self.r > 0:
                # Make sure that the weights are not merged
                self.sub_lora_data()
                self.merged = False
        else:
            if not self.merged and self.r > 0:
                # Merge the weights and mark it
                self.add_lora_data()
                self.merged = True
113
+
114
+
115
class Embedding(nn.Embedding, LoRALayer):
    """``nn.Embedding`` with an optional low-rank (LoRA) update on ``weight``."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        # LoRA is attached to the embedding matrix only.
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.register_lora_param()
        nn.Embedding.reset_parameters(self)
        self.init_lora_param()

    def init_lora_param(self):
        # Roles are swapped w.r.t. the Linear case: A starts at zero and
        # B is Gaussian, so the initial LoRA delta is still zero.
        if hasattr(self, 'w_lora_A'):
            nn.init.zeros_(self.w_lora_A)
            nn.init.normal_(self.w_lora_B)

    def train(self, mode: bool = True):
        """Propagate train/eval mode and keep the merge state consistent."""
        nn.Embedding.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        lora_active = self.r > 0 and not self.merged
        if not lora_active:
            return nn.Embedding.forward(self, x, **kwargs)
        # Differentiable merge -> stock embedding lookup -> unmerge.
        self.merge_lora_param()
        out = nn.Embedding.forward(self, x, **kwargs)
        self.sub_lora_data()
        return out
153
+
154
class LinearLoRA(nn.Linear, LoRALayer):
    """Clone of an existing ``nn.Linear`` with an optional LoRA update.

    With ``dropout_rate == 0`` the layer works by temporarily merging
    ``scaling * B @ A`` into ``weight`` around the stock linear forward.
    With dropout enabled, the LoRA branch is computed explicitly on the
    (dropped-out) input and added to the clean base output.

    ``seed`` is accepted for interface compatibility but unused here.
    """

    def __init__(
        self,
        existing_linear: nn.Linear,
        r: int = 0,
        lora_alpha: int = 1,
        fan_in_fan_out: bool = False,
        dropout_rate = 0.,
        seed: int = 1,
        **kwargs
    ):
        # BUGFIX: mirror the source layer's bias flag. The previous code
        # always created a bias, so ``load_state_dict`` (strict=True)
        # raised on bias-free linears due to the missing 'bias' key.
        super().__init__(
            in_features=existing_linear.in_features,
            out_features=existing_linear.out_features,
            bias=existing_linear.bias is not None)

        self.load_state_dict(existing_linear.state_dict())
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, fan_in_fan_out=fan_in_fan_out)

        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.register_lora_param()
        self.init_lora_param()
        # Identity unless fan_in_fan_out=True.
        self.weight.data = self.transpose(self.weight.data)
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None

    def train(self, mode: bool = True):
        """Propagate train/eval mode and keep the merge state consistent."""
        super().train(mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        if self.dropout is None:  # do as before
            if self.r > 0 and not self.merged:
                # Differentiable merge -> linear -> unmerge round trip.
                self.merge_lora_param()
                result = nn.Linear.forward(self, x, **kwargs)
                self.sub_lora_data()
                return result
            else:
                return nn.Linear.forward(self, x, **kwargs)

        # Compute the original linear transformation on the clean input.
        original_output = nn.Linear.forward(self, x)

        # LoRA-dropout: only the low-rank branch sees the dropped input.
        if self.training and self.dropout.p > 0:
            x = self.dropout(x)

        if self.r > 0 and not self.merged:
            lora_adjustment = torch.matmul(x, self.merge_BA('weight').transpose(0, 1)) * self.scaling
            result = original_output + lora_adjustment
        else:
            result = original_output
        return result
212
+
213
class Conv1d(nn.Conv1d, LoRALayer):
    """``nn.Conv1d`` with an optional low-rank (LoRA) update on ``weight``."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        # LoRA is attached to the convolution kernel only.
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # Flattened factors sized so B @ A reshapes to the kernel:
            # (out_ch/groups * k, r * k) @ (r * k, in_ch * k).
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kernel_size, r * kernel_size))
            )
            # Freeze the pre-trained kernel; only the LoRA factors train.
            self.weight.requires_grad = False
        nn.Conv1d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        """Propagate train/eval mode and keep the merge state consistent."""
        nn.Conv1d.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        if not (self.r > 0 and not self.merged):
            return nn.Conv1d.forward(self, x, **kwargs)
        # Differentiable merge -> stock convolution -> unmerge.
        self.merge_lora_param()
        out = nn.Conv1d.forward(self, x, **kwargs)
        self.sub_lora_data()
        return out
255
+
256
class Conv2d(nn.Conv2d, LoRALayer):
    """``nn.Conv2d`` with an optional low-rank (LoRA) update on ``weight``.

    Note: the factor shapes assume a square (int) kernel, enforced by the
    assert below.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        # LoRA is attached to the convolution kernel only.
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # Flattened factors sized so B @ A reshapes to the kernel:
            # (out_ch/groups * k, r * k) @ (r * k, in_ch * k).
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kernel_size, r * kernel_size))
            )
            # Freeze the pre-trained kernel; only the LoRA factors train.
            self.weight.requires_grad = False
        nn.Conv2d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        """Propagate train/eval mode and keep the merge state consistent."""
        nn.Conv2d.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        if not (self.r > 0 and not self.merged):
            return nn.Conv2d.forward(self, x, **kwargs)
        # Differentiable merge -> stock convolution -> unmerge.
        self.merge_lora_param()
        out = nn.Conv2d.forward(self, x, **kwargs)
        self.sub_lora_data()
        return out
298
+
299
class Conv3d(nn.Conv3d, LoRALayer):
    """``nn.Conv3d`` with an optional low-rank (LoRA) update on ``weight``.

    Note: the factor shapes assume a cubic (int) kernel, enforced by the
    assert below.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv3d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        # LoRA is attached to the convolution kernel only.
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # Flattened factors sized so B @ A reshapes to the kernel:
            # (out_ch/groups * k, r * k) @ (r * k, in_ch * k).
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kernel_size, r * kernel_size))
            )
            # Freeze the pre-trained kernel; only the LoRA factors train.
            self.weight.requires_grad = False
        nn.Conv3d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        """Propagate train/eval mode and keep the merge state consistent."""
        nn.Conv3d.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        if not (self.r > 0 and not self.merged):
            return nn.Conv3d.forward(self, x, **kwargs)
        # Differentiable merge -> stock convolution -> unmerge.
        self.merge_lora_param()
        out = nn.Conv3d.forward(self, x, **kwargs)
        self.sub_lora_data()
        return out
341
+
342
+
343
class PlainMultiheadAttentionLoRA(nn.Module):
    """Unfused ``nn.MultiheadAttention`` clone with optional LoRA adapters.

    The fused ``in_proj_weight`` of ``existing_mha`` is split into
    separate q/k/v linear layers (plus the output projection), and the
    projections named in ``enable_lora`` are wrapped in ``LinearLoRA`` so
    only their low-rank factors train.

    NOTE(review): ``LoRALayer.__init__`` is invoked on this class even
    though it is not a ``LoRALayer`` subclass; it only assigns plain
    attributes (r, lora_alpha, scaling, merged, ...), so this works, but
    the LoRALayer merge helpers are not available on this module.
    NOTE(review): as in ``PlainMultiHeadAttention``, ``key_padding_mask``
    is canonicalized but never folded into ``attn_mask`` — it appears to
    be ignored; confirm against callers.
    """

    def __init__(
        self,
        existing_mha: nn.MultiheadAttention,
        enable_lora: list = ['q', 'k', 'v', 'o'],
        r: int = 0,
        lora_alpha: int = 1,
        dropout_rate:float = 0.,
        **kwargs
    ):
        super().__init__()

        self.dropout = 0 # this module is not used to retrain the main block
        self.embed_dim = existing_mha.embed_dim
        self.kdim = existing_mha.kdim
        self.vdim = existing_mha.vdim
        self._qkv_same_embed_dim = existing_mha._qkv_same_embed_dim
        self.num_heads = existing_mha.num_heads
        self.batch_first = existing_mha.batch_first
        self.head_dim = existing_mha.head_dim
        #self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=existing_mha.in_proj_bias is not None)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.in_proj_bias is not None)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.in_proj_bias is not None)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.in_proj_bias is not None)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.out_proj.bias is not None)

        # Initialize parameters by slicing the fused in-projection:
        # rows [0, E) -> q, [E, 2E) -> k, [2E, 3E) -> v.
        with torch.no_grad():

            # Extract the existing weights and biases
            existing_weight = existing_mha.in_proj_weight.data
            existing_bias = existing_mha.in_proj_bias.data if existing_mha.in_proj_bias is not None else None

            # Initialize q_proj
            self.q_proj.weight.data.copy_(existing_weight[:self.embed_dim, :])
            if existing_bias is not None:
                self.q_proj.bias.data.copy_(existing_bias[:self.embed_dim])

            # Initialize k_proj
            self.k_proj.weight.data.copy_(existing_weight[self.embed_dim:2*self.embed_dim, :])
            if existing_bias is not None:
                self.k_proj.bias.data.copy_(existing_bias[self.embed_dim:2*self.embed_dim])

            # Initialize v_proj
            self.v_proj.weight.data.copy_(existing_weight[2*self.embed_dim:, :])
            if existing_bias is not None:
                self.v_proj.bias.data.copy_(existing_bias[2*self.embed_dim:])

            # Initialize proj
            self.proj.weight.data.copy_(existing_mha.out_proj.weight.data)
            if self.proj.bias is not None:
                self.proj.bias.data.copy_(existing_mha.out_proj.bias.data)

        self.scaled_dot_product_attention = F.scaled_dot_product_attention

        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, dropout_rate=dropout_rate)

        # Init qkv as a new lora linear layer
        for item in enable_lora:
            if item == 'q':
                self.q_proj = LinearLoRA(self.q_proj,
                                         r=r,
                                         lora_alpha=lora_alpha,
                                         fan_in_fan_out=False,
                                         dropout_rate = dropout_rate)
            elif item == 'k':
                self.k_proj = LinearLoRA(self.k_proj,
                                         r=r,
                                         lora_alpha=lora_alpha,
                                         fan_in_fan_out=False,
                                         dropout_rate = dropout_rate)
            elif item == 'v':
                self.v_proj = LinearLoRA(self.v_proj,
                                         r=r,
                                         lora_alpha=lora_alpha,
                                         fan_in_fan_out=False,
                                         dropout_rate = dropout_rate)
            elif item == 'o':
                self.proj = LinearLoRA(self.proj,
                                       r=r,
                                       lora_alpha=lora_alpha,
                                       fan_in_fan_out=False,
                                       dropout_rate = dropout_rate)

    def forward_module(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        average_attn_weights=True,
        is_causal=False):
        """Mirror of ``nn.MultiheadAttention.forward``.

        Returns ``(attn_output, None)`` — per-head attention weights are
        never computed by this implementation.
        """

        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")
        is_batched = query.dim() == 3
        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )

        # Work sequence-first internally: (tgt_len, bsz, embed_dim).
        if self.batch_first and is_batched:
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape
        """
        E = query.size(-1)
        qkv = self.qkv(query)
        qkv = qkv.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]
        """

        # Separate projections (possibly LoRA-wrapped) replace the fused qkv.
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=F._none_or_dtype(key_padding_mask),
            other_name="key_padding_mask",
            target_type=q.dtype,
            check_other=False,
        )

        if attn_mask is not None:
            # ensure attn_mask's dim is 3
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * self.num_heads, tgt_len, src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        if attn_mask is not None:
            # Reshape to (bsz, num_heads, tgt_len, src_len) for SDPA.
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, self.num_heads, -1, src_len)

        # self.dropout is 0 here, so dropout_p is always 0; kept for
        # structural parity with torch's reference implementation.
        dropout_p = self.dropout if self.training else 0.

        # (seq, bsz, E) -> (bsz, num_heads, seq, head_dim)
        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        src_len = k.size(1)
        q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
        k = k.view(bsz, self.num_heads, src_len, self.head_dim)
        v = v.view(bsz, self.num_heads, src_len, self.head_dim)

        attn_output = self.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
        # (bsz, heads, tgt, head_dim) -> (tgt, bsz, heads, head_dim) -> (tgt*bsz, E)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        attn_output = self.proj(attn_output)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), None
        return attn_output, None

    def train(self, mode: bool = True):
        # Sub-modules (including LinearLoRA wrappers) get the mode via the
        # default nn.Module recursion; lora_train is not available here.
        super().train(mode)
        #self.lora_train(mode)

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                **kwargs):
        """Thin wrapper delegating to :meth:`forward_module`."""
        return self.forward_module(query, key, value, **kwargs)
534
+
535
class AttentionLoRA(nn.Module):
    """timm ViT ``Attention`` block rebuilt with per-projection LoRA adapters.

    The wrapped module's fused ``qkv`` projection is split into separate
    q/k/v linears (pretrained weights copied over), and the projections named
    in ``enable_lora`` ('q', 'k', 'v', 'o') are wrapped with ``LinearLoRA``.
    ``forward`` matches timm's contract, optionally also returning the
    pre-softmax attention scores.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            existing_mha: Attention,
            enable_lora: list = ['q', 'k', 'v', 'o'],
            r: int = 0,
            lora_alpha: int = 1,
            dropout_rate: float = 0.,
            seed: int = 1,
    ) -> None:
        """Wrap *existing_mha* with LoRA adapters.

        Args:
            existing_mha: pretrained timm attention module to copy weights from.
            enable_lora: which projections receive LoRA ('q', 'k', 'v', 'o').
            r: LoRA rank (0 disables the adapters).
            lora_alpha: LoRA scaling numerator.
            dropout_rate: dropout inside the LoRA branches.
            seed: RNG seed for reproducible adapter init.
        """
        super().__init__()

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        self.embed_dim = existing_mha.proj.in_features
        self.num_heads = existing_mha.num_heads
        self.head_dim = existing_mha.head_dim
        assert self.embed_dim % self.num_heads == 0, 'dim should be divisible by num_heads'
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()
        self.dropout = 0  # attention/projection dropout disabled during fine-tuning
        self.q_norm = existing_mha.q_norm
        self.k_norm = existing_mha.k_norm
        self.attn_drop = nn.Dropout(self.dropout)
        self.proj_drop = nn.Dropout(self.dropout)
        self.r = r
        self.lora_alpha = lora_alpha
        self.dropout_rate = dropout_rate
        self.enable_lora = enable_lora
        self.seed = seed
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, dropout_rate=dropout_rate)

        has_qkv_bias = existing_mha.qkv.bias is not None
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=has_qkv_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=has_qkv_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=has_qkv_bias)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.proj.bias is not None)

        # Copy pretrained weights; the fused qkv weight is laid out [q; k; v]
        # along dim 0.
        with torch.no_grad():
            existing_weight = existing_mha.qkv.weight.data
            # BUG FIX: the original read ``existing_mha.qkv.bias.data``
            # unconditionally, which raises AttributeError when qkv has no
            # bias even though every copy below is guarded.
            existing_bias = existing_mha.qkv.bias.data if has_qkv_bias else None

            self.q_proj.weight.data.copy_(existing_weight[:self.embed_dim, :])
            if existing_bias is not None:
                self.q_proj.bias.data.copy_(existing_bias[:self.embed_dim])

            self.k_proj.weight.data.copy_(existing_weight[self.embed_dim:2 * self.embed_dim, :])
            if existing_bias is not None:
                self.k_proj.bias.data.copy_(existing_bias[self.embed_dim:2 * self.embed_dim])

            self.v_proj.weight.data.copy_(existing_weight[2 * self.embed_dim:, :])
            if existing_bias is not None:
                self.v_proj.bias.data.copy_(existing_bias[2 * self.embed_dim:])

            self.proj.weight.data.copy_(existing_mha.proj.weight.data)
            if self.proj.bias is not None:
                self.proj.bias.data.copy_(existing_mha.proj.bias.data)

        self.q_proj, self.k_proj, self.v_proj, self.proj = self.inject_lora(
            self.q_proj, self.k_proj, self.v_proj, self.proj)

    def inject_lora(self, q, k, v, proj):
        """Wrap the projections selected in ``self.enable_lora`` with LinearLoRA.

        Returns the (possibly wrapped) q, k, v and output projections.
        """
        def _wrap(layer):
            # One-line purpose: build a LinearLoRA adapter around *layer*
            # with this module's shared LoRA hyper-parameters.
            return LinearLoRA(layer,
                              r=self.r,
                              lora_alpha=self.lora_alpha,
                              fan_in_fan_out=False,
                              dropout_rate=self.dropout_rate,
                              seed=self.seed)

        if 'q' in self.enable_lora:
            q = _wrap(q)
        if 'k' in self.enable_lora:
            k = _wrap(k)
        if 'v' in self.enable_lora:
            v = _wrap(v)
        if 'o' in self.enable_lora:
            proj = _wrap(proj)
        return q, k, v, proj

    def forward(self, x: torch.Tensor, return_attn_scores=False) -> torch.Tensor:
        """Multi-head self-attention over (B, N, C) tokens.

        Returns the attended tensor, or ``(output, pre-softmax scores)`` when
        *return_attn_scores* is True (that path always uses the unfused
        implementation so the scores are materialized).
        """
        B, N, C = x.shape
        q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        q, k = self.q_norm(q), self.k_norm(k)

        if return_attn_scores:
            q = q * self.scale
            attn_scores = q @ k.transpose(-2, -1)
            attn = attn_scores.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
            x = x.transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)

            return (x, attn_scores)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
666
+
667
class BertAttentionLoRA(nn.Module):
    """HF ``BertAttention`` rebuilt with per-projection LoRA adapters.

    Copies the pretrained query/key/value/output projections into fresh
    linears, wraps those named in ``enable_lora`` with ``LinearLoRA``, and
    re-implements the standard BERT self-attention forward (including the
    relative-position variants and decoder key/value caching).
    """

    def __init__(self,
                 existing_mha: BertAttention,
                 enable_lora: list = ['q', 'k', 'v', 'o'],
                 r: int = 0,
                 lora_alpha: int = 1,
                 dropout_rate: float = 0.,
                 seed: int = 1,):
        """Wrap *existing_mha* (a pretrained BertAttention) with LoRA adapters."""
        super().__init__()

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        self.self_attn = existing_mha.self
        self.output = existing_mha.output
        self.num_attention_heads = self.self_attn.num_attention_heads
        self.attention_head_size = self.self_attn.attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.hidden_size = self.self_attn.query.in_features

        self.q_proj = nn.Linear(self.hidden_size, self.all_head_size)
        self.k_proj = nn.Linear(self.hidden_size, self.all_head_size)
        self.v_proj = nn.Linear(self.hidden_size, self.all_head_size)
        self.proj = nn.Linear(self.output.dense.in_features, self.output.dense.in_features)
        self.LayerNorm = self.output.LayerNorm
        self.dropout = nn.Dropout(0)  # attention dropout disabled for fine-tuning

        self.r = r
        self.lora_alpha = lora_alpha
        self.dropout_rate = dropout_rate
        self.enable_lora = enable_lora
        self.seed = seed
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, dropout_rate=dropout_rate)

        # Copy pretrained weights into the fresh projections.
        # BUG FIX: the original guards tested ``layer.bias.data is not None``,
        # which raises AttributeError when ``bias is None`` and is vacuously
        # true otherwise; the correct check is on ``layer.bias`` itself.
        with torch.no_grad():
            self.q_proj.weight.data.copy_(self.self_attn.query.weight.data)
            if self.self_attn.query.bias is not None:
                self.q_proj.bias.data.copy_(self.self_attn.query.bias.data)

            self.k_proj.weight.data.copy_(self.self_attn.key.weight.data)
            if self.self_attn.key.bias is not None:
                self.k_proj.bias.data.copy_(self.self_attn.key.bias.data)

            self.v_proj.weight.data.copy_(self.self_attn.value.weight.data)
            if self.self_attn.value.bias is not None:
                self.v_proj.bias.data.copy_(self.self_attn.value.bias.data)

            self.proj.weight.data.copy_(self.output.dense.weight.data)
            if self.output.dense.bias is not None:
                self.proj.bias.data.copy_(self.output.dense.bias.data)

        self.q_proj, self.k_proj, self.v_proj, self.proj = self.inject_lora(
            self.q_proj, self.k_proj, self.v_proj, self.proj)

        self.position_embedding_type = self.self_attn.position_embedding_type
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = self.self_attn.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * self.self_attn.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = self.self_attn.is_decoder

    def inject_lora(self, q, k, v, proj):
        """Wrap the projections selected in ``self.enable_lora`` with LinearLoRA."""
        def _wrap(layer):
            # Build a LinearLoRA adapter around *layer* with shared hyper-params.
            return LinearLoRA(layer,
                              r=self.r,
                              lora_alpha=self.lora_alpha,
                              fan_in_fan_out=False,
                              dropout_rate=self.dropout_rate,
                              seed=self.seed)

        if 'q' in self.enable_lora:
            q = _wrap(q)
        if 'k' in self.enable_lora:
            k = _wrap(k)
        if 'v' in self.enable_lora:
            v = _wrap(v)
        if 'o' in self.enable_lora:
            proj = _wrap(proj)
        return q, k, v, proj

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, S, all_head) -> (B, heads, S, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """BERT attention forward with LoRA projections.

        Mirrors HF's BertSelfAttention + BertSelfOutput (dense, residual,
        LayerNorm) in one module. Returns ``(attention_output, [attn_probs],
        [past_key_value])`` following the HF output convention.
        """
        mixed_query_layer = self.q_proj(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.k_proj(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        self_attn_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            self_attn_outputs = self_attn_outputs + (past_key_value,)

        # Output projection + residual + LayerNorm (BertSelfOutput equivalent).
        self_outputs = self.proj(self_attn_outputs[0])
        attention_output = self.LayerNorm(self_outputs + hidden_states)
        outputs = (attention_output,) + self_attn_outputs[1:]  # add attentions if we output them
        return outputs
871
+
872
+
873
class MergedLinear(nn.Linear, LoRALayer):
    """nn.Linear with LoRA applied to a subset of equally sized output groups.

    *enable_lora* is a per-group flag list (e.g. ``[True, False, True]`` for a
    fused qkv projection with LoRA on q and v only); its length must divide
    ``out_features``. Relies on ``LoRALayer`` helpers (``init_lora_param``,
    ``merge_lora_param``, ``sub_lora_data``, ``lora_train``, ``transpose``)
    defined elsewhere in this module.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        # Map of base-parameter name -> LoRA prefix for its A/B factors.
        self.params_with_lora = {'weight': 'w'}
        if r > 0 and any(enable_lora):
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )  # weights for Conv1D with groups=sum(enable_lora)
            # Freeze the pre-trained weight; only the LoRA factors train.
            self.weight.requires_grad = False
            # Boolean mask of output rows that belong to LoRA-enabled groups.
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        nn.Linear.reset_parameters(self)
        self.init_lora_param()
        self.weight.data = self.transpose(self.weight.data)

    def zero_pad(self, x):
        """Scatter LoRA delta rows back to full output size (zeros elsewhere)."""
        result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
        result[self.lora_ind] = x
        return result

    def merge_BA(self, param_name: str):
        """Compute the zero-padded delta W = B @ A for *param_name*.

        Uses a grouped conv1d so each enabled group gets its own rank-r factor.
        """
        lora_name = self.params_with_lora[param_name]
        # FIX: use getattr instead of eval() — same attribute lookup without
        # routing names through the interpreter.
        lora_A = getattr(self, f'{lora_name}_lora_A')
        lora_B = getattr(self, f'{lora_name}_lora_B')
        delta_w = F.conv1d(
            lora_A.unsqueeze(0),
            lora_B.unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        return self.transpose(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        """Standard mode switch plus LoRA merge/unmerge bookkeeping."""
        nn.Linear.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        """Linear forward; temporarily merges LoRA deltas when unmerged."""
        if self.r > 0 and not self.merged:
            self.merge_lora_param()
            result = nn.Linear.forward(self, x, **kwargs)
            self.sub_lora_data()
            return result
        else:
            return nn.Linear.forward(self, x, **kwargs)
loralib/utils.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Dict
6
+
7
+ from .layers import LoRALayer, AttentionLoRA, BertAttentionLoRA
8
+ from timm.models.vision_transformer import Attention
9
+ from transformers.models.bert.modeling_bert import BertAttention
10
+
11
+
12
# Transformer layer indices (0-based, 12-layer encoders) to which LoRA is
# applied, keyed by the ``position`` argument. Text and vision tables are
# kept separate because their supported keys differ slightly.
INDEX_POSITIONS_TEXT = {
    'top1': [11],
    'top2': [10, 11],
    'top3': [9, 10, 11],
    'bottom': [0, 1, 2, 3],
    'mid': [4, 5, 6, 7],
    'up': [8, 9, 10, 11],
    'half-up': [6, 7, 8, 9, 10, 11],
    'half-bottom': [0, 1, 2, 3, 4, 5],
    'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]}


INDEX_POSITIONS_VISION = {
    'top': [11],
    'top3': [9, 10, 11],
    'bottom': [0, 1, 2, 3],
    'mid': [4, 5, 6, 7],
    'up': [8, 9, 10, 11],
    'half-up': [6, 7, 8, 9, 10, 11],
    'half-bottom': [0, 1, 2, 3, 4, 5],
    'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
}
34
+
35
+
36
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    """Freeze every non-LoRA parameter of *model* in place.

    Args:
        model: module whose parameters are scanned by name.
        bias: 'none' freezes all biases; 'all' re-enables every bias;
            'lora_only' re-enables only biases on LoRALayer modules.

    Raises:
        NotImplementedError: for an unknown *bias* mode.
    """
    for name, param in model.named_parameters():
        if 'lora_' not in name:
            param.requires_grad = False

    if bias == 'none':
        return
    if bias == 'all':
        for name, param in model.named_parameters():
            if 'bias' in name:
                param.requires_grad = True
    elif bias == 'lora_only':
        for module in model.modules():
            if isinstance(module, LoRALayer) and getattr(module, 'bias', None) is not None:
                module.bias.requires_grad = True
    else:
        raise NotImplementedError
54
+
55
+
56
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Return the subset of ``model.state_dict()`` worth checkpointing for LoRA.

    Args:
        model: module to snapshot.
        bias: 'none' -> only ``lora_`` entries; 'all' -> also every bias;
            'lora_only' -> also the sibling bias of each ``lora_`` entry.

    Raises:
        NotImplementedError: for an unknown *bias* mode.
    """
    state = model.state_dict()
    if bias == 'none':
        return {name: tensor for name, tensor in state.items() if 'lora_' in name}
    if bias == 'all':
        return {name: tensor for name, tensor in state.items()
                if 'lora_' in name or 'bias' in name}
    if bias == 'lora_only':
        selected: Dict[str, torch.Tensor] = {}
        for name, tensor in state.items():
            if 'lora_' not in name:
                continue
            selected[name] = tensor
            # A LoRA entry's owning module may carry a 'bias' sibling.
            bias_key = name.split('lora_')[0] + 'bias'
            if bias_key in state:
                selected[bias_key] = state[bias_key]
        return selected
    raise NotImplementedError
73
+
74
+
75
def get_lora_parameters(model, bias='none'):
    """Collect the LoRA parameters of *model* as a list.

    Args:
        model: module whose parameters are scanned by name.
        bias: 'none' -> only ``lora_`` params; 'all' -> also every bias
            parameter; 'lora_only' -> also the sibling bias of each ``lora_``
            parameter.

    Returns:
        list of ``nn.Parameter`` in iteration order.

    Raises:
        NotImplementedError: for an unknown *bias* mode (now checked up
        front instead of only once the first parameter is seen).
    """
    if bias not in ('none', 'all', 'lora_only'):
        raise NotImplementedError
    # Hoisted: the original rebuilt dict(model.named_parameters()) and
    # model.state_dict() inside the loop (accidental O(n^2)); it also checked
    # membership against state_dict() but looked keys up in named_parameters(),
    # which would KeyError if the sibling 'bias' were a buffer.
    named = dict(model.named_parameters())
    params = []
    for name, param in named.items():
        if 'lora_' in name:
            params.append(param)
            if bias == 'lora_only':
                bias_name = name.split('lora_')[0] + 'bias'
                if bias_name in named:
                    params.append(named[bias_name])
        elif bias == 'all' and 'bias' in name:
            params.append(param)
    return params
94
+
95
+
96
def apply_lora(args, clip_model):
    """Swap attention modules of *clip_model* for LoRA-wrapped versions in place.

    Replaces ``BertAttention`` modules in the text encoder and timm
    ``Attention`` modules in the vision trunk, at the layer indices selected
    by ``args.position`` (looked up in the INDEX_POSITIONS_* tables — the key
    must exist in BOTH tables).

    Returns:
        list of the injected LoRA modules, text layers first.
    """
    list_lora_layers = []
    indices = INDEX_POSITIONS_TEXT[args.position]
    # NOTE(review): assumes an open_clip CustomTextCLIP with an HF BERT text
    # tower (clip_model.text.transformer.encoder.layer) — confirm for other
    # architectures.
    text_encoder = clip_model.text.transformer.encoder
    for i, block in enumerate(text_encoder.layer):
        if i in indices:
            for name, submodule in block.named_children():
                if isinstance(submodule, BertAttention):
                    new_multi_head_lora = BertAttentionLoRA(
                        submodule, enable_lora=args.params, r=args.r, lora_alpha=args.alpha, dropout_rate=args.dropout_rate, seed=args.seed)
                    setattr(block, name, new_multi_head_lora)
                    list_lora_layers.append(new_multi_head_lora)

    indices = INDEX_POSITIONS_VISION[args.position]
    vision_encoder = clip_model.visual.trunk
    for i, block in enumerate(vision_encoder.blocks):
        if i in indices:
            for name, submodule in block.named_children():
                if isinstance(submodule, Attention):
                    new_multi_head_lora = AttentionLoRA(
                        submodule, enable_lora=args.params, r=args.r, lora_alpha=args.alpha, dropout_rate=args.dropout_rate, seed=args.seed)
                    setattr(block, name, new_multi_head_lora)
                    list_lora_layers.append(new_multi_head_lora)
    return list_lora_layers
120
+
121
+
122
def save_lora(args, list_lora_layers, loss_fn, msg, save_dir):
    """Serialize LoRA A/B factors (plus optional loss-side state) to disk.

    The checkpoint is ``{'weights': ..., 'metadata': ...}``; per-layer keys
    are ``layer_{i}`` with one sub-dict per projection enabled in
    ``args.params`` ('q'/'k'/'v' -> q_proj/k_proj/v_proj, 'o' -> proj).
    ``load_model`` below expects exactly this layout.
    """
    weights = {}
    for i, layer in enumerate(list_lora_layers):
        layer_weights = {}
        if 'q' in args.params:
            layer_weights['q_proj'] = {
                'w_lora_A': layer.q_proj.w_lora_A.data,
                'w_lora_B': layer.q_proj.w_lora_B.data
            }
        if 'k' in args.params:
            layer_weights['k_proj'] = {
                'w_lora_A': layer.k_proj.w_lora_A.data,
                'w_lora_B': layer.k_proj.w_lora_B.data
            }
        if 'v' in args.params:
            layer_weights['v_proj'] = {
                'w_lora_A': layer.v_proj.w_lora_A.data,
                'w_lora_B': layer.v_proj.w_lora_B.data
            }
        if 'o' in args.params:
            layer_weights['proj'] = {
                'w_lora_A': layer.proj.w_lora_A.data,
                'w_lora_B': layer.proj.w_lora_B.data
            }

        weights[f'layer_{i}'] = layer_weights

    # Extra state for the hypergraph-adapter loss variant.
    if args.loss_type == 'clip_loss_ace_hgnn':
        weights['img_edge_adapter'] = loss_fn.img_edge_adapter.state_dict()
        weights['img_node_adapter'] = loss_fn.img_node_adapter.state_dict()
        weights['text_edge_adapter'] = loss_fn.text_edge_adapter.state_dict()
        weights['text_node_adapter'] = loss_fn.text_node_adapter.state_dict()

    if args.learnable_logit_scale:
        weights['logit_scale'] = loss_fn.logit_scale.data.cpu()

    metadata = {
        'r': args.r,
        'topk': args.topk,
        'params': args.params,
        'position': args.position,
        'loss_type' : args.loss_type,
    }

    save_data = {
        'weights': weights,
        'metadata': metadata
    }

    # NOTE(review): assumes save_dir already exists (no makedirs here).
    save_path = f'{save_dir}/{args.filename}_{msg}.pt'
    torch.save(save_data, save_path)
    print(f'LoRA weights saved to {save_path}')
175
def load_model(args, list_lora_layers, device, loss_fn=None):
    """Load LoRA factors (and optional loss-side state) saved by ``save_lora``.

    Copies the stored A/B matrices into the live ``list_lora_layers`` in
    order; layer count and shapes must match the saved checkpoint.

    Raises:
        FileNotFoundError: if ``args.load_path`` does not exist.
    """
    if not os.path.exists(args.load_path):
        raise FileNotFoundError(f'File {args.load_path} does not exist.')

    # NOTE(review): torch.load unpickles arbitrary objects — consider
    # weights_only=True for untrusted checkpoints.
    loaded_data = torch.load(args.load_path, map_location=device)

    weights = loaded_data['weights']
    for i, layer in enumerate(list_lora_layers):
        layer_weights = weights[f'layer_{i}']
        if 'q' in args.params and 'q_proj' in layer_weights:
            layer.q_proj.w_lora_A.data.copy_(
                layer_weights['q_proj']['w_lora_A'])
            layer.q_proj.w_lora_B.data.copy_(
                layer_weights['q_proj']['w_lora_B'])
        if 'k' in args.params and 'k_proj' in layer_weights:
            layer.k_proj.w_lora_A.data.copy_(
                layer_weights['k_proj']['w_lora_A'])
            layer.k_proj.w_lora_B.data.copy_(
                layer_weights['k_proj']['w_lora_B'])
        if 'v' in args.params and 'v_proj' in layer_weights:
            layer.v_proj.w_lora_A.data.copy_(
                layer_weights['v_proj']['w_lora_A'])
            layer.v_proj.w_lora_B.data.copy_(
                layer_weights['v_proj']['w_lora_B'])
        if 'o' in args.params and 'proj' in layer_weights:
            layer.proj.w_lora_A.data.copy_(layer_weights['proj']['w_lora_A'])
            layer.proj.w_lora_B.data.copy_(layer_weights['proj']['w_lora_B'])

    # Restore the hypergraph-adapter loss state when that variant was saved.
    if args.loss_type == 'clip_loss_ace_hgnn':
        loss_fn.img_edge_adapter.load_state_dict(weights['img_edge_adapter'])
        loss_fn.img_node_adapter.load_state_dict(weights['img_node_adapter'])
        loss_fn.text_edge_adapter.load_state_dict(weights['text_edge_adapter'])
        loss_fn.text_node_adapter.load_state_dict(weights['text_node_adapter'])

    if args.learnable_logit_scale:
        loss_fn.logit_scale.data.copy_(weights['logit_scale'])

    print(f'LoRA weights loaded from {args.load_path}')
loss.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from run_utils import set_random_seed
8
+
9
class Identity(nn.Module):
    """No-op module: returns its input unchanged (stand-in for ``nn.Identity``)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Pass *x* through untouched."""
        return x
+
16
class CLIPLoss(nn.Module):
    """Symmetric InfoNCE (CLIP) loss over paired image/text features.

    When *merged_df* is provided, off-diagonal pairs whose full label vectors
    coincide are masked out of the negatives (label-guided InfoNCE); the
    diagonal is always kept.
    """

    def __init__(self, args, logit_scale):
        """``args.learnable_logit_scale`` decides whether the temperature is a
        trainable Parameter or a fixed buffer."""
        super().__init__()
        self.args = args
        scale = logit_scale.clone().detach()
        if args.learnable_logit_scale:
            self.logit_scale = nn.Parameter(scale)
        else:
            self.register_buffer('logit_scale', scale)

    def forward(self, image_features, text_features, merged_df=None, indices=None):
        """Return the mean of image->text and text->image cross-entropies.

        image_features / text_features: (B, D), assumed L2-normalized by the
        caller. merged_df/indices: optional label table + row indices used to
        build the negative mask.
        """
        device = image_features.device
        n = image_features.size(0)
        targets = torch.arange(n, device=device, dtype=torch.long)

        img_logits = self.logit_scale * image_features @ text_features.t()
        txt_logits = img_logits.T

        if merged_df is None:
            return (F.cross_entropy(img_logits, targets)
                    + F.cross_entropy(txt_logits, targets)) / 2

        # Label-guided masking: rows/cols of samples with identical label
        # vectors are removed from the negatives; diagonal stays valid.
        label_mat = merged_df.iloc[indices, 2:].to_numpy()
        keep = np.ones((label_mat.shape[0], label_mat.shape[0]), dtype=np.int32)
        same = (label_mat[:, None, :] == label_mat[None, :, :]).all(axis=2)
        keep[same] = 0
        np.fill_diagonal(keep, 1)
        keep = torch.from_numpy(keep).bool().to(device)
        img_masked = img_logits.masked_fill(~keep, float('-inf'))
        txt_masked = txt_logits.masked_fill(~keep.T, float('-inf'))
        return (F.cross_entropy(img_masked, targets)
                + F.cross_entropy(txt_masked, targets)) / 2
+
49
class ResidualAdapter(nn.Module):
    """Bottleneck adapter: Linear(dim->bottleneck) -> LeakyReLU(0.2) -> Linear(bottleneck->dim).

    NOTE(review): despite the name, no residual/skip connection is applied
    here — callers receive only the projected output.
    """

    def __init__(self, dim, bottleneck_dim=128):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck_dim)
        self.act = nn.LeakyReLU(0.2)
        self.up = nn.Linear(bottleneck_dim, dim)
        # Kaiming-normal init on both weight matrices; biases keep the
        # nn.Linear default. (Order matters for RNG reproducibility.)
        nn.init.kaiming_normal_(self.down.weight)
        nn.init.kaiming_normal_(self.up.weight)

    def forward(self, x):
        """Project down, apply the nonlinearity, project back to *dim*."""
        hidden = self.act(self.down(x))
        return self.up(hidden)
61
+
62
+
63
class CLIPLossACE_HGAT(nn.Module):
    """CLIP InfoNCE loss with hypergraph-attention (ACE-HGAT) feature refinement.

    Token features of the image and/or text tower are refined through a
    sparse attention-derived graph (top-k neighbors) and small adapter MLPs
    before the contrastive loss is computed on the [CLS]/global tokens.
    Requires the timm and open_clip monkey-patches from this repo so that
    attention scores are exposed.
    """

    def __init__(self, args, logit_scale, in_channels):
        super(CLIPLossACE_HGAT, self).__init__()
        set_random_seed(args.seed)
        self.args = args
        # Separate edge/node adapter pairs per modality.
        self.img_edge_adapter = ResidualAdapter(in_channels, args.hidden_features)
        self.text_edge_adapter = ResidualAdapter(in_channels, args.hidden_features)
        self.img_node_adapter = ResidualAdapter(in_channels, args.hidden_features)
        self.text_node_adapter = ResidualAdapter(in_channels, args.hidden_features)

        if args.learnable_logit_scale:
            self.logit_scale = nn.Parameter(logit_scale.clone().detach())
        else:
            self.register_buffer('logit_scale', logit_scale.clone().detach())

    def apply_ace_hgat(self, features, attn_weights, encoder="img"):
        """Refine (B, N, D) token features with a top-k attention graph.

        Args:
            features: token embeddings, CLS/global token at index 0.
            attn_weights: CLS->token attention row, shape (B, N-1).
            encoder: 'img' or 'text' — selects the adapter pair.

        Returns:
            refined (B, N, D) features.

        Raises:
            ValueError: for an unknown *encoder* tag.
        """
        if encoder =="img":
            edge_adapter = self.img_edge_adapter
            node_adapter = self.img_node_adapter
        elif encoder == 'text':
            edge_adapter = self.text_edge_adapter
            node_adapter = self.text_node_adapter
        else:
            raise ValueError(f"encoder must be img or text but given {encoder}")

        B, N, D = features.shape
        patches_norm = F.normalize(features[:, 1:, :], p=2, dim=-1)
        # Similarity Matrix: (B, P, P)
        sim = torch.zeros(size=(B, N, N), device=features.device)
        patch_sim = torch.bmm(patches_norm, patches_norm.transpose(1, 2))  # [B, P, P]
        sim[:, 1:, 1:] = patch_sim
        # Row 0 holds the CLS->token attention instead of cosine similarity.
        sim[:, 0, 1:] = attn_weights
        # Mask the diagonal and the token->CLS column before top-k selection.
        mask_logic = torch.eye(N, device=features.device).bool().unsqueeze(0).repeat(B, 1, 1)
        mask_logic[:, 1:, 0] = True
        sim = sim.masked_fill(mask_logic, -float('inf'))
        # Keep only the top-k neighbors per node, softmax over them.
        topk_vals, topk_indices = torch.topk(sim, k=self.args.topk, dim=-1)
        mask_sparse = torch.full_like(sim, -float('inf'))
        mask_sparse.scatter_(-1, topk_indices, topk_vals)
        A = F.softmax(mask_sparse, dim=-1)
        # Re-insert self-loops and symmetrize the CLS row/column.
        A = A.masked_fill(torch.eye(N, device=features.device).bool().unsqueeze(0).repeat(B, 1, 1), 1)
        A[:, 1:, 0] = A[:, 0, 1:]

        # Edge aggregation -> adapter -> node aggregation -> adapter.
        H_edges_raw = torch.matmul(A, features)
        H_edges_refined = edge_adapter(H_edges_raw)
        H_context_raw = torch.matmul(A.transpose(1, 2), H_edges_refined)
        H_context_processed = node_adapter(H_context_raw)
        x_out = H_context_processed

        return x_out

    def forward(self, clip_model, images, texts, merged_df=None, indices=None):
        """Encode both modalities through *clip_model*, refine with ACE-HGAT
        per ``args.apply_gnn_encoders`` ('vision' | 'text' | 'both'), and
        return the symmetric InfoNCE loss.

        NOTE(review): if ``args.apply_gnn_encoders`` is none of the three
        expected values, ``logits_per_image`` is never assigned and the NaN
        check below raises NameError — confirm allowed values upstream.
        """
        device = images.device
        # Disable pooling so the trunk returns the full token sequence.
        clip_model.visual.trunk.global_pool = ''
        # get_attn_scores / output_attentions are provided by the repo's
        # timm and open_clip patches — presumably patched before training.
        image_features, img_attn_scores = clip_model.visual.trunk.get_attn_scores(images)
        image_features = F.normalize(clip_model.visual.head(image_features), dim=-1)
        text_features, text_attn_scores = clip_model.encode_text(texts, normalize=True, output_attentions=True, output_tokens=True)
        img_attn_scores = img_attn_scores.mean(dim=1)  # [B, 197, 197]
        img_attn_weights = img_attn_scores[:, 0, 1:]  # relationship between CLS token and patch embeddings [B, 196]

        text_attn_scores = text_attn_scores[-1].mean(dim=1)  # [B, 256, 256]
        text_attn_weights = text_attn_scores[:, 0, 1:]  # relationship between global token and other token embeddings [B, 255]


        if self.args.apply_gnn_encoders == 'vision':
            image_features = self.apply_ace_hgat(image_features, img_attn_weights, encoder="img")
            image_features = F.normalize(image_features, dim=-1)

            logits_per_image = self.logit_scale * image_features[:, 0] @ text_features[:, 0].t()
            logits_per_text = logits_per_image.T

        elif self.args.apply_gnn_encoders == 'text':
            text_features = self.apply_ace_hgat(text_features, text_attn_weights, encoder="text")
            text_features = F.normalize(text_features, dim=-1)

            logits_per_image = self.logit_scale * image_features[:, 0] @ text_features[:, 0].t()
            logits_per_text = logits_per_image.T

        elif self.args.apply_gnn_encoders == 'both':
            image_features = self.apply_ace_hgat(image_features, img_attn_weights, encoder="img")
            image_features = F.normalize(image_features, dim=-1)

            text_features = self.apply_ace_hgat(text_features, text_attn_weights, encoder="text")
            text_features = F.normalize(text_features, dim=-1)

            logits_per_image = self.logit_scale * image_features[:, 0] @ text_features[:, 0].t()
            logits_per_text = logits_per_image.T

        labels = torch.arange(image_features.shape[0], device=device, dtype=torch.long)

        if logits_per_image.isnan().sum() > 0:
            raise ValueError('NaN value in logits_per_image')

        if merged_df is not None:  # Label-Guided InfoNCE loss
            compare_matrix = merged_df.iloc[indices, 2:].to_numpy()
            vector_similarity_matrix = np.ones((compare_matrix.shape[0], compare_matrix.shape[0]), dtype=np.int32)
            comparison = (compare_matrix[:, None, :] == compare_matrix[None, :, :]).all(axis=2)
            vector_similarity_matrix[comparison] = 0
            np.fill_diagonal(vector_similarity_matrix, 1)
            vector_similarity_matrix = torch.from_numpy(vector_similarity_matrix).bool().to(device)
            masked_logits_per_image = logits_per_image.masked_fill(~vector_similarity_matrix, float('-inf'))
            masked_logits_per_text = logits_per_text.masked_fill(~vector_similarity_matrix.T, float('-inf'))
            loss = (F.cross_entropy(masked_logits_per_image, labels) + F.cross_entropy(masked_logits_per_text, labels)) / 2
        else:
            loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2

        return loss
open_clip_patch.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.models.bert import modeling_bert
2
+ from open_clip import CustomTextCLIP
3
+ from open_clip.hf_model import HFTextEncoder
4
+ import torch.nn.functional as F
5
+ from torch import TensorType
6
+
7
def patch_encode_text():
    """Monkey-patch open_clip so ``encode_text`` can also return attention scores.

    Replaces ``CustomTextCLIP.encode_text`` and ``HFTextEncoder.forward`` in
    place; call once at startup before using the model.
    """

    def encode_text_patched(self, text, normalize: bool = False, output_attentions = False, output_tokens = False):
        # With output_attentions the patched text tower returns
        # (features, attention_scores); otherwise just the features.
        if output_attentions:
            features, attn_scores = self.text(text, output_attentions = output_attentions, output_tokens = output_tokens)
            features = F.normalize(features, dim=-1) if normalize else features
            return features, attn_scores
        else:
            features = self.text(text, output_attentions = output_attentions, output_tokens = output_tokens)
            return F.normalize(features, dim=-1) if normalize else features

    def HFText_encoder_patched(self, x: TensorType, output_attentions=False, output_tokens=False):
        # NOTE(review): mutates self.output_tokens per call — not thread-safe;
        # confirm this is acceptable for the training setup.
        self.output_tokens = output_tokens
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask, output_attentions=output_attentions)
        if self.output_tokens:
            # Project the full token sequence (out[0] = last_hidden_state).
            tokens = self.proj(out[0])
            if output_attentions:
                # NOTE(review): out[1] is assumed to be the per-layer attention
                # tuple — verify against the HF model's output ordering.
                return tokens, out[1]
            else:
                return tokens
        else:
            # Default open_clip path: pool, then project.
            pooled_out = self.pooler(out, attn_mask)
            projected = self.proj(pooled_out)

            return projected

    CustomTextCLIP.encode_text = encode_text_patched
    HFTextEncoder.forward = HFText_encoder_patched
+ HFTextEncoder.forward = HFText_encoder_patched
36
+
37
+
prompt_templates.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
def _xray_image_prompt(c):
    """Template: plain-radiograph phrasing for condition *c*."""
    return f'a chest X-ray image of {c}.'


def _findings_prompt(c):
    """Template: findings-style phrasing for condition *c*."""
    return f'Findings suggesting {c}.'


# Callables mapping a condition name to a complete prompt sentence.
prompt_templates = [_xray_image_prompt, _findings_prompt]
timm_vit_return_attn_patch.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import timm.models.vision_transformer as vit
3
+
4
def patch_timm_vit_return_attn_scores():
    """Monkey-patch timm's ViT so (pre-softmax) attention scores can be retrieved.

    Patches three entry points:
      * ``vit.Attention.forward`` -- optional ``return_attn_scores`` flag; when
        set, runs the unfused attention math and also returns the scaled
        ``q @ k^T`` logits (before softmax).
      * ``vit.Block.forward``     -- threads the flag through to its attention.
      * ``vit.VisionTransformer.get_attn_scores`` -- new method: a full forward
        pass that captures the scores of the final transformer block.

    Calls without the flag delegate to the original implementations, so
    default behavior is unchanged.
    """
    original_attention_forward = vit.Attention.forward

    def attn_forward_patched(self, x, return_attn_scores=False):
        # Fast path: defer to timm's own (possibly fused) implementation.
        if not return_attn_scores:
            return original_attention_forward(self, x)

        batch, tokens, channels = x.shape
        qkv = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # Scaled dot-product attention kept unfused so the raw scores stay visible.
        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(out))
        return (out, scores)

    vit.Attention.forward = attn_forward_patched

    original_block_forward = vit.Block.forward

    def block_forward_patched(self, x, return_attn_scores=False):
        # Unpatched behavior unless scores are explicitly requested.
        if not return_attn_scores:
            return original_block_forward(self, x)

        attn_out, scores = self.attn(self.norm1(x), return_attn_scores=True)
        x = x + self.drop_path1(self.ls1(attn_out))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return (x, scores)

    vit.Block.forward = block_forward_patched

    def get_attn_scores(self, x, pre_logits: bool = False):
        """Forward pass returning (logits-or-features, final-block attention scores)."""
        x = self.norm_pre(self.patch_drop(self._pos_embed(self.patch_embed(x))))

        # Only the last block is asked for its scores; earlier blocks run normally.
        last = len(self.blocks) - 1
        for idx, block in enumerate(self.blocks):
            if idx == last:
                x, attn_scores = block(x, return_attn_scores=True)
            else:
                x = block(x)
        x = self.norm(x)

        # Head: pool, normalize, then (optionally) classify.
        if self.global_pool:
            if self.global_pool == 'avg':
                x = x[:, self.num_prefix_tokens:].mean(dim=1)
            else:
                x = x[:, 0]
        x = self.head_drop(self.fc_norm(x))
        if not pre_logits:
            x = self.head(x)
        return (x, attn_scores)

    vit.VisionTransformer.get_attn_scores = get_attn_scores